In [1]:
import logging
import random
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm.notebook import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

# Configure logging once for the whole notebook: INFO level, with a timestamp
# and the level name in front of every message.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
# getLogger() without a name returns the root logger.
logger = logging.getLogger()

class Config:
    """Settings container for the NSFW/SFW ViT fine-tuning notebook.

    The constructor takes no arguments; every setting is assigned a fixed
    default and exposed as a plain instance attribute.
    """

    def __init__(self):
        # Compute device: CUDA when PyTorch reports it as available, else CPU.
        cuda_ok = torch.cuda.is_available()
        self.device = torch.device("cuda" if cuda_ok else "cpu")

        # Pretrained checkpoint identifier and number of output classes.
        self.model_name = "Falconsai/nsfw_image_detection"
        self.num_classes = 2

        # Optimisation hyperparameters. The trailing values are the
        # alternatives the author left next to the active ones.
        self.batch_size = 16
        self.num_epochs = 6  # 10
        self.learning_rate = 5e-6  # 2e-5
        self.weight_decay = 0.01

        # Fractions of the data used for training and validation.
        self.train_ratio = 0.7
        self.val_ratio = 0.15

        # Number of DataLoader worker processes.
        self.num_workers = 4

        # Output location for the fine-tuned model.
        self.model_save_path = "./models/nsfw_vit_classifier"

        # Input image directories, one per data source.
        self.nsfw_dir = "./newdata/nsfw"
        self.sfw_humans_dir = "./newdata/sfw-human"
        self.sfw_anime_dir = "./newdata/sfw-anime"

        # Display names indexed by class id (0 -> "SFW", 1 -> "NSFW").
        self.class_names = ["SFW", "NSFW"]

# Single module-level settings instance; later cells read it as CONFIG.
CONFIG = Config()
print(f"Using device: {CONFIG.device}")

Using device: cuda


In [2]:
class NSFWDataset(Dataset):
    """Map-style dataset that turns image files into ``(pixel_values, label)`` pairs.

    Each image is opened with PIL, converted to RGB, run through a torchvision
    transform (random augmentation or a plain resize to 224x224) and then
    through the supplied image processor.

    Args:
        image_paths: Paths of the image files.
        labels: Integer class label per image, index-aligned with ``image_paths``.
        processor: Image processor, called as
            ``processor(images=image, return_tensors="pt")``.
        augment: When True, apply random horizontal flip, rotation, colour
            jitter and a random resized crop; otherwise only resize.
    """

    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        processor: ViTImageProcessor,
        augment: bool = False,
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor

        if augment:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(15),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            ])
        else:
            self.transform = transforms.Compose([
                transforms.Resize((224, 224)),
            ])

    def __len__(self) -> int:
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Optional[Tuple[torch.Tensor, int]]:
        """Load, transform and preprocess the sample at ``idx``.

        Returns:
            ``(pixel_values, label)`` with the processor's batch dimension
            squeezed away, or ``None`` when the image could not be loaded or
            processed (the error is logged). Note that PyTorch's default
            collate function cannot batch ``None`` entries.
        """
        img_path = self.image_paths[idx]
        try:
            # Image.open() is lazy and holds on to the file handle. convert()
            # returns an independent copy, so the source file can be closed
            # as soon as the conversion is done.
            with Image.open(img_path) as raw_image:
                image = raw_image.convert("RGB")
            image = self.transform(image)
            processed = self.processor(images=image, return_tensors="pt")
            pixel_values = processed["pixel_values"].squeeze(0)
            return pixel_values, self.labels[idx]
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            return None

In [3]:
def _list_image_files(directory: str, valid_extensions: Set[str]) -> List[str]:
    """Return the sorted image file paths directly inside ``directory``.

    A file qualifies when its lower-cased suffix is in ``valid_extensions`` or
    its name starts with ``"did:"``. A missing directory yields an empty list.
    """
    root = Path(directory)
    if not root.exists():
        return []
    # Sorting removes the dependence on the OS-specific iterdir() order, which
    # is what makes a seeded shuffle reproducible.
    return sorted(
        str(entry)
        for entry in root.iterdir()
        if entry.is_file()
        and (entry.suffix.lower() in valid_extensions or entry.name.startswith("did:"))
    )


def load_image_paths_and_labels(
    config: "Config", seed: Optional[int] = None
) -> Tuple[List[str], List[int]]:
    """Collect image paths from the configured directories and label them.

    Files under ``config.nsfw_dir`` get label 1; files under
    ``config.sfw_humans_dir`` and ``config.sfw_anime_dir`` get label 0.
    No class rebalancing is performed. The samples are shuffled before being
    returned.

    Args:
        config: Object exposing ``nsfw_dir``, ``sfw_humans_dir`` and
            ``sfw_anime_dir``.
        seed: Optional seed for the shuffle. When given, the returned order is
            reproducible; when ``None`` the global ``random`` state is used.

    Returns:
        ``(image_paths, labels)`` as two index-aligned lists. Both are empty
        when no images are found.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".webp"}

    nsfw_paths = _list_image_files(config.nsfw_dir, valid_extensions)
    print(f"NSFW images: {len(nsfw_paths)}")

    sfw_paths = (
        _list_image_files(config.sfw_humans_dir, valid_extensions)
        + _list_image_files(config.sfw_anime_dir, valid_extensions)
    )
    print(f"SFW images: {len(sfw_paths)}")

    # Pair every path with its label up front so the two can never drift out
    # of alignment, whatever the shuffle does.
    samples = [(path, 1) for path in nsfw_paths] + [(path, 0) for path in sfw_paths]

    shuffler = random.Random(seed) if seed is not None else random
    shuffler.shuffle(samples)

    print(f"Total images: {len(samples)}")
    image_paths = [path for path, _ in samples]
    labels = [label for _, label in samples]
    return image_paths, labels

In [4]:
def create_dataloaders(image_paths, labels, processor, config, augment=False):
    """Split the samples into train/val/test sets and wrap each in a DataLoader.

    Split sizes come from ``config.train_ratio`` and ``config.val_ratio``; the
    test set receives whatever remains. The split indices are drawn with a
    generator seeded with 42.

    Args:
        image_paths: Image file paths.
        labels: Class labels, index-aligned with ``image_paths``.
        processor: Image processor forwarded to every ``NSFWDataset``.
        config: Object exposing ``train_ratio``, ``val_ratio``, ``batch_size``,
            ``num_workers`` and ``device``.
        augment: Whether the training dataset applies augmentation. The
            validation and test datasets never do.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    n_total = len(image_paths)
    n_train = int(config.train_ratio * n_total)
    n_val = int(config.val_ratio * n_total)
    n_test = n_total - n_train - n_val

    # This dataset is only used to obtain the split's index lists; the real
    # datasets are rebuilt below so each can get its own augmentation flag.
    index_source = NSFWDataset(image_paths, labels, processor, augment=False)
    train_part, val_part, test_part = random_split(
        index_source,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(42),
    )

    def build_subset(indices, use_augmentation):
        # Materialise one split as a standalone dataset.
        subset_paths = [image_paths[i] for i in indices]
        subset_labels = [labels[i] for i in indices]
        return NSFWDataset(subset_paths, subset_labels, processor, augment=use_augmentation)

    train_dataset = build_subset(train_part.indices, augment)
    val_dataset = build_subset(val_part.indices, False)
    test_dataset = build_subset(test_part.indices, False)

    # Settings shared by all three loaders; memory is pinned only on CUDA.
    shared_kwargs = {
        "batch_size": config.batch_size,
        "num_workers": config.num_workers,
        "pin_memory": config.device.type == "cuda",
    }
    train_loader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_kwargs)

    print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")
    return train_loader, val_loader, test_loader

In [5]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one training pass over ``train_loader``.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result has a
            ``logits`` attribute.
        train_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Loss function applied to ``(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch losses averaged over the number
        of batches, and the accuracy in percent over all samples seen.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in tqdm(train_loader, desc="Training"):
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(pixel_values=batch_images).logits
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()
        # Predicted class = index of the largest logit in each row.
        predictions = torch.max(logits, 1)[1]
        n_seen += batch_labels.size(0)
        n_correct += (predictions == batch_labels).sum().item()

    mean_loss = loss_sum / len(train_loader)
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy

def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` with gradients disabled.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result has a
            ``logits`` attribute.
        val_loader: Iterable of ``(images, labels)`` batches supporting ``len()``.
        criterion: Loss function applied to ``(logits, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy, all_preds, all_labels)``: the per-batch losses
        averaged over the number of batches, the accuracy in percent, and the
        predicted and true class ids of every sample in iteration order.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(val_loader, desc="Validating"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(pixel_values=batch_images).logits
            loss_sum += criterion(logits, batch_labels).item()

            # Predicted class = index of the largest logit in each row.
            predictions = torch.max(logits, 1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predictions == batch_labels).sum().item()

            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    mean_loss = loss_sum / len(val_loader)
    accuracy = 100.0 * n_correct / n_seen
    return mean_loss, accuracy, all_preds, all_labels

In [6]:
def train_model(model, train_loader, val_loader, config):
    """Fine-tune ``model`` and restore the weights of its best validation epoch.

    Trains for ``config.num_epochs`` epochs with AdamW, label-smoothed
    cross-entropy (0.1) and a ReduceLROnPlateau scheduler driven by the
    validation loss. After the last epoch the weights from the epoch with the
    highest validation accuracy are loaded back into the model.

    Args:
        model: Model to train; updated in place.
        train_loader: Batches passed to ``train_epoch``.
        val_loader: Batches passed to ``validate``.
        config: Object exposing ``num_epochs``, ``learning_rate``,
            ``weight_decay`` and ``device``.

    Returns:
        ``(model, history)`` where ``history`` maps ``"train_loss"``,
        ``"train_acc"``, ``"val_loss"`` and ``"val_acc"`` to one value per epoch.
    """
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)

    best_val_acc = 0.0
    best_model_state = None
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}

    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch + 1}/{config.num_epochs}")

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.device)
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)

        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() returns tensors that share storage with the live
            # parameters, and dict.copy() is shallow. Without cloning, later
            # optimizer steps would overwrite this snapshot and the "best"
            # weights would just be the final ones.
            best_model_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            print("✓ New best model!")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    return model, history

In [7]:
# Image processor belonging to the pretrained checkpoint named in CONFIG.
processor = ViTImageProcessor.from_pretrained(CONFIG.model_name)
# Gather the shuffled (path, label) lists, then build the three loaders;
# augmentation is enabled for the training split.
image_paths, labels = load_image_paths_and_labels(CONFIG)
train_loader, val_loader, test_loader = create_dataloaders(
    image_paths, labels, processor, CONFIG, augment=True
)

# Load the pretrained classifier with a CONFIG.num_classes-way head and move it
# to the configured device.
model = ViTForImageClassification.from_pretrained(
    CONFIG.model_name,
    num_labels=CONFIG.num_classes,
    ignore_mismatched_sizes=True,
).to(CONFIG.device)

print(f"Model loaded on {CONFIG.device}")

NSFW images: 13595
SFW images: 13017
Total images: 26612
Train: 18628, Val: 3991, Test: 3993
Model loaded on cuda


In [None]:
# Run the fine-tuning loop; `history` holds the per-epoch loss/accuracy curves.
model, history = train_model(model, train_loader, val_loader, CONFIG)


Epoch 1/6


Training:   0%|          | 0/1165 [00:00<?, ?it/s]

Validating:   0%|          | 0/250 [00:00<?, ?it/s]

Train Loss: 0.3366 | Train Acc: 94.98%
Val Loss: 0.2687 | Val Acc: 96.37%
✓ New best model!

Epoch 2/6


Training:   0%|          | 0/1165 [00:00<?, ?it/s]

In [None]:
# Training curves side by side: loss on the left, accuracy on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))


def _draw_curves(ax, train_key, val_key, ylabel, title):
    """Plot the train and validation series of one metric on ``ax``."""
    ax.plot(history[train_key], label="Train")
    ax.plot(history[val_key], label="Val")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.set_title(title)


_draw_curves(ax1, "train_loss", "val_loss", "Loss", "Loss")
_draw_curves(ax2, "train_acc", "val_acc", "Accuracy (%)", "Accuracy")

plt.tight_layout()
plt.show()

In [None]:
def show_confusion_matrix(model, test_loader, config):
    """Evaluate ``model`` on ``test_loader`` and report per-class metrics.

    Prints a classification report, draws the confusion matrix as a heatmap
    and prints the overall accuracy.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result has a
            ``logits`` attribute.
        test_loader: Iterable of ``(images, labels)`` batches.
        config: Object exposing ``device`` and ``class_names`` (display names
            indexed by class id).

    Returns:
        ``(all_preds, all_labels)`` as NumPy arrays in iteration order.
    """
    model.eval()
    all_preds, all_labels = [], []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images = images.to(config.device)
            outputs = model(pixel_values=images)
            _, predicted = torch.max(outputs.logits, 1)

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())

    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    # Pin the class ids explicitly. Without this, scikit-learn infers them
    # from the data, so a run in which only one class occurs would make the
    # report reject `target_names` and shrink the matrix to 1x1 while the
    # heatmap still carries one tick label per configured class.
    class_ids = list(range(len(config.class_names)))

    print("Classification Report:")
    print(classification_report(
        all_labels, all_preds,
        labels=class_ids, target_names=config.class_names, digits=4,
    ))

    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=config.class_names, yticklabels=config.class_names)
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.show()

    accuracy = (all_preds == all_labels).sum() / len(all_labels)
    print(f"\nAccuracy: {accuracy:.2%}")

    return all_preds, all_labels

# Evaluate on the held-out test loader; keep the predictions and labels for later cells.
all_preds, all_labels = show_confusion_matrix(model, test_loader, CONFIG)

In [None]:
def _show_misclassified_grid(records, prob_key, prob_label, title, top_n, n_cols=5):
    """Display the first ``top_n`` records as an image grid ``n_cols`` wide.

    Args:
        records: Dicts holding a ``"path"`` entry and a ``prob_key`` entry,
            already in display order.
        prob_key: Key of the probability shown above each image.
        prob_label: Text placed before ``" prob: "`` in each image title.
        title: Figure title.
        top_n: Maximum number of images to show; nothing is drawn when it is
            not positive.
        n_cols: Number of grid columns.
    """
    shown = records[:top_n] if top_n > 0 else []
    if not shown:
        return

    # Size the grid from top_n (ceiling division) so the figure layout does not
    # depend on how many errors happened to be found. top_n=10 gives 2x5.
    n_rows = (top_n + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 3 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, record in zip(axes, shown):
        # Close the file once the RGB copy has been handed to matplotlib.
        with Image.open(record["path"]) as raw_image:
            ax.imshow(raw_image.convert("RGB"))
        ax.set_title(f"{prob_label} prob: {record[prob_key]:.2%}", fontsize=9)
        ax.axis("off")

    # Hide the unused cells of the grid.
    for ax in axes[len(shown):]:
        ax.axis("off")

    plt.suptitle(title, fontsize=14)
    plt.tight_layout()
    plt.show()


def review_misclassified(model, test_loader, config, top_n=10):
    """Find and display the test images the model is most confidently wrong about.

    Args:
        model: Module called as ``model(pixel_values=...)`` whose result has a
            ``logits`` attribute.
        test_loader: Loader whose ``dataset`` exposes ``image_paths``. Paths are
            looked up by a running sample counter, so the loader must yield
            samples in dataset order (no shuffling).
        config: Object exposing ``device``.
        top_n: Maximum number of images shown per error type.

    Returns:
        ``(false_positives, false_negatives)``: lists of dicts sorted by
        descending confidence. Every dict has ``"path"`` and ``"confidence"``;
        false positives also carry ``"nsfw_prob"``, false negatives ``"sfw_prob"``.
    """
    model.eval()

    false_positives = []
    false_negatives = []

    dataset = test_loader.dataset

    with torch.no_grad():
        sample_idx = 0
        for images, labels in tqdm(test_loader, desc="Finding misclassified"):
            images = images.to(config.device)
            outputs = model(pixel_values=images)
            probs = torch.softmax(outputs.logits, dim=1)
            _, predicted = torch.max(outputs.logits, 1)

            for i in range(len(labels)):
                true_label = labels[i].item()
                pred_label = predicted[i].item()

                if true_label != pred_label:
                    record = {
                        "path": dataset.image_paths[sample_idx],
                        "confidence": probs[i, pred_label].item(),
                    }
                    if true_label == 0 and pred_label == 1:
                        record["nsfw_prob"] = probs[i, 1].item()
                        false_positives.append(record)
                    elif true_label == 1 and pred_label == 0:
                        record["sfw_prob"] = probs[i, 0].item()
                        false_negatives.append(record)

                sample_idx += 1

    # Most confident mistakes first.
    false_positives = sorted(false_positives, key=lambda x: x["confidence"], reverse=True)
    false_negatives = sorted(false_negatives, key=lambda x: x["confidence"], reverse=True)

    if false_positives:
        _show_misclassified_grid(
            false_positives, "nsfw_prob", "NSFW", "False Positives (SFW → NSFW)", top_n
        )
    else:
        print("No false positives found!")

    if false_negatives:
        _show_misclassified_grid(
            false_negatives, "sfw_prob", "SFW", "False Negatives (NSFW → SFW)", top_n
        )
    else:
        print("No false negatives found!")

    print(f"\nTotal: {len(false_positives)} false positives, {len(false_negatives)} false negatives")

    return false_positives, false_negatives

# Collect and display the misclassified test images; both lists are reused in the next cell.
false_positives, false_negatives = review_misclassified(model, test_loader, CONFIG, top_n=10)

In [None]:
# List the file paths of the first ten entries of each error type.
for heading, records in (
    ("FALSE POSITIVE PATHS:", false_positives),
    ("\nFALSE NEGATIVE PATHS:", false_negatives),
):
    print(heading)
    for record in records[:10]:
        print(f"  {record['path']}")

In [None]:
def save_model(model, processor, path="./nsfw_classifier"):
    """Write the model and its processor to ``path`` via ``save_pretrained``.

    The directory, including missing parents, is created first; an existing
    directory is reused.
    """
    target_dir = Path(path)
    target_dir.mkdir(parents=True, exist_ok=True)

    # Model first, then processor, both into the same directory.
    for artifact in (model, processor):
        artifact.save_pretrained(path)

    print(f"Model saved to {path}")

# save_model(model, processor, CONFIG.model_save_path)

In [None]:
def load_model(path="./nsfw_classifier"):
    """Load a saved classifier and its image processor from ``path``.

    The model is moved to ``CONFIG.device`` and switched to evaluation mode.

    Returns:
        ``(model, processor)``.
    """
    classifier = ViTForImageClassification.from_pretrained(path)
    classifier = classifier.to(CONFIG.device)
    image_processor = ViTImageProcessor.from_pretrained(path)

    classifier.eval()
    print(f"Model loaded from {path}")
    return classifier, image_processor

# model, processor = load_model(CONFIG.model_save_path)

In [None]:
import requests
from io import BytesIO

def predict_url(url, timeout=30.0):
    """Download an image, classify it with the notebook's model and display it.

    Uses the module-level ``model``, ``processor`` and ``CONFIG`` objects.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds passed to ``requests.get`` as its timeout. Without a
            timeout an unresponsive server would block the kernel indefinitely.

    Returns:
        Dict with the class probabilities under ``"sfw"`` and ``"nsfw"``.

    Raises:
        requests.RequestException: On connection problems, timeouts, or an HTTP
            error status (via ``raise_for_status``).
    """
    model.eval()

    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")

    inputs = processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(pixel_values=inputs["pixel_values"].to(CONFIG.device))
        probs = torch.softmax(outputs.logits, dim=1)[0]

    # Convert to plain floats once; index 0 is SFW and index 1 is NSFW.
    sfw_prob = probs[0].item()
    nsfw_prob = probs[1].item()

    plt.figure(figsize=(6, 6))
    plt.imshow(image)
    plt.title(f"SFW: {sfw_prob:.2%} | NSFW: {nsfw_prob:.2%}")
    plt.axis("off")
    plt.show()

    return {"sfw": sfw_prob, "nsfw": nsfw_prob}

# Image URL to classify, written on its own line inside a triple-quoted string.
url_to_predict = """
https://cdn.bsky.app/img/feed_thumbnail/plain/did:plc:j55kvyxp44daiz4yvx4g4ari/bafkreiebtixeda6jgde6gtsbffza7yxozijbisewev7i5ykbyw7m5rsboi@jpeg
"""

# strip() removes the newlines that surround the URL inside the literal.
predict_url(url_to_predict.strip())