{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "7ab47b82-29e6-473b-8c13-af2390900d1c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using device: cuda\n" ] } ], "source": [ "import logging\n", "import random\n", "from io import BytesIO\n", "from pathlib import Path\n", "from typing import Any, Dict, List, Optional, Set, Tuple, Union\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import seaborn as sns\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "from PIL import Image\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "from torch.utils.data import DataLoader, Dataset, random_split\n", "from torchvision import transforms\n", "from tqdm.notebook import tqdm\n", "from transformers import ViTForImageClassification, ViTImageProcessor\n", "\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", "logger = logging.getLogger()\n", "\n", "class Config:\n", " def __init__(self):\n", " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " self.model_name = \"Falconsai/nsfw_image_detection\"\n", " self.num_classes = 2\n", " self.batch_size = 16\n", " self.num_epochs = 6 # 10\n", " self.learning_rate = 5e-6 # 2e-5\n", " self.weight_decay = 0.01\n", " self.train_ratio = 0.7\n", " self.val_ratio = 0.15\n", " self.num_workers = 4\n", " self.model_save_path = \"./models/nsfw_vit_classifier\"\n", " self.nsfw_dir = \"./newdata/nsfw\"\n", " self.sfw_humans_dir = \"./newdata/sfw-human\"\n", " self.sfw_anime_dir = \"./newdata/sfw-anime\"\n", " self.class_names = [\"SFW\", \"NSFW\"]\n", "\n", "CONFIG = Config()\n", "print(f\"Using device: {CONFIG.device}\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "1ae26bf1-3e5e-4542-ab89-e9e5de6758fe", "metadata": {}, "outputs": [], "source": [ "class NSFWDataset(Dataset):\n", " def __init__(\n", " self,\n", " image_paths: List[str],\n", " labels: List[int],\n", " processor: ViTImageProcessor,\n", " augment: bool = False,\n", " ):\n", " self.image_paths = image_paths\n", " self.labels = labels\n", " self.processor = processor\n", "\n", " if augment:\n", " self.transform = transforms.Compose([\n", " transforms.RandomHorizontalFlip(p=0.5),\n", " transforms.RandomRotation(15),\n", " transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),\n", " transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),\n", " ])\n", " else:\n", " self.transform = transforms.Compose([\n", " transforms.Resize((224, 224)),\n", " ])\n", "\n", " def __len__(self) -> int:\n", " return len(self.image_paths)\n", "\n", " def __getitem__(self, idx: int):\n", " img_path = self.image_paths[idx]\n", " try:\n", " image = Image.open(img_path).convert(\"RGB\")\n", " image = self.transform(image)\n", " processed = self.processor(images=image, return_tensors=\"pt\")\n", " pixel_values = processed[\"pixel_values\"].squeeze(0)\n", " return pixel_values, self.labels[idx]\n", " except Exception as e:\n", " logger.error(f\"Error loading image {img_path}: {e}\")\n", " return None" ] }, { "cell_type": "code", "execution_count": 3, "id": "7fdad2e1-0976-46db-be9e-675ae59d09e2", "metadata": {}, "outputs": [], "source": [ "def load_image_paths_and_labels(config: Config) -> Tuple[List[str], List[int]]:\n", " nsfw_paths = []\n", " sfw_paths = []\n", " override_paths = []\n", " valid_extensions = {\".jpg\", \".jpeg\", \".png\", \".webp\"}\n", "\n", " nsfw_path = Path(config.nsfw_dir)\n", " 
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7fdad2e1-0976-46db-be9e-675ae59d09e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image_paths_and_labels(config: Config) -> Tuple[List[str], List[int]]:\n",
    "    nsfw_paths = []\n",
    "    sfw_paths = []\n",
    "    valid_extensions = {\".jpg\", \".jpeg\", \".png\", \".webp\"}\n",
    "\n",
    "    # Accept known image extensions, plus extension-less files named after their DID.\n",
    "    nsfw_path = Path(config.nsfw_dir)\n",
    "    if nsfw_path.exists():\n",
    "        nsfw_paths = [str(f) for f in nsfw_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
    "    print(f\"NSFW images: {len(nsfw_paths)}\")\n",
    "\n",
    "    sfw_humans_path = Path(config.sfw_humans_dir)\n",
    "    if sfw_humans_path.exists():\n",
    "        sfw_paths += [str(f) for f in sfw_humans_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
    "\n",
    "    sfw_anime_path = Path(config.sfw_anime_dir)\n",
    "    if sfw_anime_path.exists():\n",
    "        sfw_paths += [str(f) for f in sfw_anime_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
    "    print(f\"SFW images: {len(sfw_paths)}\")\n",
    "\n",
    "    nsfw_count, sfw_count = len(nsfw_paths), len(sfw_paths)\n",
    "    # if nsfw_count > 0 and sfw_count > 0:\n",
    "    #     ratio = nsfw_count / sfw_count\n",
    "    #     target_ratio = 1.2\n",
    "    #     if ratio > 2.0 or ratio < 0.5:\n",
    "    #         if ratio > target_ratio:\n",
    "    #             nsfw_paths = random.sample(nsfw_paths, int(sfw_count * target_ratio))\n",
    "    #         else:\n",
    "    #             sfw_paths = random.sample(sfw_paths, int(nsfw_count / target_ratio))\n",
    "\n",
    "    image_paths = nsfw_paths + sfw_paths\n",
    "    labels = [1] * len(nsfw_paths) + [0] * len(sfw_paths)\n",
    "\n",
    "    combined = list(zip(image_paths, labels))\n",
    "    random.Random(42).shuffle(combined)  # seeded so the shuffle is reproducible, matching the seeded split below\n",
    "    image_paths, labels = zip(*combined)\n",
    "\n",
    "    print(f\"Total images: {len(image_paths)}\")\n",
    "    return list(image_paths), list(labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "21b22a1e-b33a-4dd7-802c-7bd65eb85b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloaders(image_paths, labels, processor, config, augment=False):\n",
    "    total_size = len(image_paths)\n",
    "    train_size = int(config.train_ratio * total_size)\n",
    "    val_size = int(config.val_ratio * total_size)\n",
    "    test_size = total_size - train_size - val_size\n",
    "\n",
    "    # random_split is used only to derive seeded index splits; fresh datasets are built\n",
    "    # from those indices so augmentation applies to the training set alone.\n",
    "    full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)\n",
    "    train_temp, val_temp, test_temp = random_split(\n",
    "        full_dataset, [train_size, val_size, test_size],\n",
    "        generator=torch.Generator().manual_seed(42)\n",
    "    )\n",
    "\n",
    "    train_dataset = NSFWDataset(\n",
    "        [image_paths[i] for i in train_temp.indices],\n",
    "        [labels[i] for i in train_temp.indices],\n",
    "        processor, augment=augment\n",
    "    )\n",
    "    val_dataset = NSFWDataset(\n",
    "        [image_paths[i] for i in val_temp.indices],\n",
    "        [labels[i] for i in val_temp.indices],\n",
    "        processor, augment=False\n",
    "    )\n",
    "    test_dataset = NSFWDataset(\n",
    "        [image_paths[i] for i in test_temp.indices],\n",
    "        [labels[i] for i in test_temp.indices],\n",
    "        processor, augment=False\n",
    "    )\n",
    "\n",
    "    pin_memory = config.device.type == \"cuda\"\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True,\n",
    "                              num_workers=config.num_workers, pin_memory=pin_memory)\n",
    "    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,\n",
    "                            num_workers=config.num_workers, pin_memory=pin_memory)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,\n",
    "                             num_workers=config.num_workers, pin_memory=pin_memory)\n",
    "\n",
    "    print(f\"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")\n",
    "    return train_loader, val_loader, test_loader"
   ]
  },
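  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0a3b5c7-1e2d-4f6a-8b9c-3d5e7f9a1b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hedged alternative to the commented-out downsampling in load_image_paths_and_labels:\n",
    "# instead of discarding images to rebalance the classes, weight the loss by inverse class\n",
    "# frequency so every image is kept. make_weighted_criterion is a hypothetical helper and is\n",
    "# not wired into train_model below; to adopt it, build the criterion from the training\n",
    "# labels and pass it through in place of the internal CrossEntropyLoss.\n",
    "def make_weighted_criterion(train_labels: List[int], config: Config) -> nn.CrossEntropyLoss:\n",
    "    counts = np.bincount(train_labels, minlength=config.num_classes)  # [sfw_count, nsfw_count]\n",
    "    class_weights = counts.sum() / (config.num_classes * counts)  # inverse-frequency weights\n",
    "    return nn.CrossEntropyLoss(\n",
    "        weight=torch.tensor(class_weights, dtype=torch.float32).to(config.device),\n",
    "        label_smoothing=0.1,  # same smoothing as the criterion in train_model\n",
    "    )"
   ]
  },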
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43bef262-249c-471e-92b0-2782ea8b9241",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, train_loader, criterion, optimizer, device):\n",
    "    model.train()\n",
    "    running_loss, correct, total = 0.0, 0, 0\n",
    "\n",
    "    for images, labels in tqdm(train_loader, desc=\"Training\"):\n",
    "        images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(pixel_values=images)\n",
    "        loss = criterion(outputs.logits, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        running_loss += loss.item()\n",
    "        _, predicted = torch.max(outputs.logits, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "    return running_loss / len(train_loader), 100.0 * correct / total\n",
    "\n",
    "def validate(model, val_loader, criterion, device):\n",
    "    model.eval()\n",
    "    running_loss, correct, total = 0.0, 0, 0\n",
    "    all_preds, all_labels = [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels in tqdm(val_loader, desc=\"Validating\"):\n",
    "            images, labels = images.to(device), labels.to(device)\n",
    "\n",
    "            outputs = model(pixel_values=images)\n",
    "            loss = criterion(outputs.logits, labels)\n",
    "\n",
    "            running_loss += loss.item()\n",
    "            _, predicted = torch.max(outputs.logits, 1)\n",
    "            total += labels.size(0)\n",
    "            correct += (predicted == labels).sum().item()\n",
    "\n",
    "            all_preds.extend(predicted.cpu().numpy())\n",
    "            all_labels.extend(labels.cpu().numpy())\n",
    "\n",
    "    return running_loss / len(val_loader), 100.0 * correct / total, all_preds, all_labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d655026d-0a94-4f93-abe9-1dc4c0209c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, val_loader, config):\n",
    "    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
    "    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)\n",
    "    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=\"min\", factor=0.5, patience=2)\n",
    "\n",
    "    best_val_acc = 0.0\n",
    "    best_model_state = None\n",
    "    history = {\"train_loss\": [], \"train_acc\": [], \"val_loss\": [], \"val_acc\": []}\n",
    "\n",
    "    for epoch in range(config.num_epochs):\n",
    "        print(f\"\\nEpoch {epoch + 1}/{config.num_epochs}\")\n",
    "\n",
    "        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.device)\n",
    "        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)\n",
    "\n",
    "        scheduler.step(val_loss)\n",
    "\n",
    "        history[\"train_loss\"].append(train_loss)\n",
    "        history[\"train_acc\"].append(train_acc)\n",
    "        history[\"val_loss\"].append(val_loss)\n",
    "        history[\"val_acc\"].append(val_acc)\n",
    "\n",
    "        print(f\"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%\")\n",
    "        print(f\"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%\")\n",
    "\n",
    "        if val_acc > best_val_acc:\n",
    "            best_val_acc = val_acc\n",
    "            # state_dict() returns live tensor references; clone them so later epochs don't overwrite the snapshot.\n",
    "            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
    "            print(\"✓ New best model!\")\n",
    "\n",
    "    if best_model_state is not None:\n",
    "        model.load_state_dict(best_model_state)\n",
    "\n",
    "    return model, history"
   ]
  },
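  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c1d2e3f-4a5b-4c6d-8e7f-9a0b1c2d3e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional variant (sketch, not used below): at a learning rate this small the entire ViT is\n",
    "# fine-tuned. On a smaller dataset, freezing the pretrained encoder and training only the\n",
    "# classification head is a common lighter-weight alternative. ViTForImageClassification\n",
    "# exposes the encoder as `.vit` and the head as `.classifier`; if you adopt this, call it\n",
    "# before building the optimizer so only head parameters receive gradients.\n",
    "def freeze_backbone(model: ViTForImageClassification) -> None:\n",
    "    for param in model.vit.parameters():\n",
    "        param.requires_grad = False  # leaves only model.classifier trainable\n",
    "    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"Trainable parameters after freezing: {trainable:,}\")"
   ]
  },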
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2da4a196-6947-4f85-8a69-0d9a0d8dfd97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NSFW images: 13595\n",
      "SFW images: 13017\n",
      "Total images: 26612\n",
      "Train: 18628, Val: 3991, Test: 3993\n",
      "Model loaded on cuda\n"
     ]
    }
   ],
   "source": [
    "processor = ViTImageProcessor.from_pretrained(CONFIG.model_name)\n",
    "image_paths, labels = load_image_paths_and_labels(CONFIG)\n",
    "train_loader, val_loader, test_loader = create_dataloaders(\n",
    "    image_paths, labels, processor, CONFIG, augment=True\n",
    ")\n",
    "\n",
    "model = ViTForImageClassification.from_pretrained(\n",
    "    CONFIG.model_name,\n",
    "    num_labels=CONFIG.num_classes,\n",
    "    ignore_mismatched_sizes=True,\n",
    ").to(CONFIG.device)\n",
    "\n",
    "print(f\"Model loaded on {CONFIG.device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37bd659-fb18-4445-83b0-1044d9239ea7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1/6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b71483a651d6426cabc38ecd118d0334",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: 0%| | 0/1165 [00:00