this repo has no description
at main 25 kB view raw
1{ 2 "cells": [ 3 { 4 "cell_type": "code", 5 "execution_count": 1, 6 "id": "7ab47b82-29e6-473b-8c13-af2390900d1c", 7 "metadata": {}, 8 "outputs": [ 9 { 10 "name": "stdout", 11 "output_type": "stream", 12 "text": [ 13 "Using device: cuda\n" 14 ] 15 } 16 ], 17 "source": [ 18 "import logging\n", 19 "import random\n", 20 "from io import BytesIO\n", 21 "from pathlib import Path\n", 22 "from typing import Any, Dict, List, Optional, Set, Tuple, Union\n", 23 "\n", 24 "import matplotlib.pyplot as plt\n", 25 "import numpy as np\n", 26 "import seaborn as sns\n", 27 "import torch\n", 28 "import torch.nn as nn\n", 29 "import torch.optim as optim\n", 30 "from PIL import Image\n", 31 "from sklearn.metrics import classification_report, confusion_matrix\n", 32 "from torch.utils.data import DataLoader, Dataset, random_split\n", 33 "from torchvision import transforms\n", 34 "from tqdm.notebook import tqdm\n", 35 "from transformers import ViTForImageClassification, ViTImageProcessor\n", 36 "\n", 37 "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", 38 "logger = logging.getLogger()\n", 39 "\n", 40 "class Config:\n", 41 " def __init__(self):\n", 42 " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", 43 " self.model_name = \"Falconsai/nsfw_image_detection\"\n", 44 " self.num_classes = 2\n", 45 " self.batch_size = 16\n", 46 " self.num_epochs = 6 # 10\n", 47 " self.learning_rate = 5e-6 # 2e-5\n", 48 " self.weight_decay = 0.01\n", 49 " self.train_ratio = 0.7\n", 50 " self.val_ratio = 0.15\n", 51 " self.num_workers = 4\n", 52 " self.model_save_path = \"./models/nsfw_vit_classifier\"\n", 53 " self.nsfw_dir = \"./newdata/nsfw\"\n", 54 " self.sfw_humans_dir = \"./newdata/sfw-human\"\n", 55 " self.sfw_anime_dir = \"./newdata/sfw-anime\"\n", 56 " self.class_names = [\"SFW\", \"NSFW\"]\n", 57 "\n", 58 "CONFIG = Config()\n", 59 "print(f\"Using device: {CONFIG.device}\")" 60 ] 61 }, 62 { 63 "cell_type": "code", 64 "execution_count": 2, 65 "id": "1ae26bf1-3e5e-4542-ab89-e9e5de6758fe", 66 "metadata": {}, 67 "outputs": [], 68 "source": [ 69 "class NSFWDataset(Dataset):\n", 70 " def __init__(\n", 71 " self,\n", 72 " image_paths: List[str],\n", 73 " labels: List[int],\n", 74 " processor: ViTImageProcessor,\n", 75 " augment: bool = False,\n", 76 " ):\n", 77 " self.image_paths = image_paths\n", 78 " self.labels = labels\n", 79 " self.processor = processor\n", 80 "\n", 81 " if augment:\n", 82 " self.transform = transforms.Compose([\n", 83 " transforms.RandomHorizontalFlip(p=0.5),\n", 84 " transforms.RandomRotation(15),\n", 85 " transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),\n", 86 " transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),\n", 87 " ])\n", 88 " else:\n", 89 " self.transform = transforms.Compose([\n", 90 " transforms.Resize((224, 224)),\n", 91 " ])\n", 92 "\n", 93 " def __len__(self) -> int:\n", 94 " return len(self.image_paths)\n", 95 "\n", 96 " def __getitem__(self, idx: int):\n", 97 " img_path = self.image_paths[idx]\n", 98 " try:\n", 99 " image = Image.open(img_path).convert(\"RGB\")\n", 100 " image = self.transform(image)\n", 101 " processed = self.processor(images=image, return_tensors=\"pt\")\n", 102 " pixel_values = processed[\"pixel_values\"].squeeze(0)\n", 103 " return pixel_values, self.labels[idx]\n", 104 " except Exception as e:\n", 105 " logger.error(f\"Error loading image {img_path}: {e}\")\n", 106 " return None" 107 ] 108 }, 109 { 110 "cell_type": "code", 111 "execution_count": 3, 112 "id": "7fdad2e1-0976-46db-be9e-675ae59d09e2", 113 "metadata": {}, 114 "outputs": [], 115 "source": [ 116 "def load_image_paths_and_labels(config: Config) -> Tuple[List[str], List[int]]:\n", 117 " nsfw_paths = []\n", 118 " sfw_paths = []\n", 119 " override_paths = []\n", 120 " valid_extensions = {\".jpg\", \".jpeg\", \".png\", \".webp\"}\n", 121 "\n", 122 " nsfw_path = Path(config.nsfw_dir)\n", 123 " if nsfw_path.exists():\n", 124 " nsfw_paths = [str(f) for f in nsfw_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n", 125 " print(f\"NSFW images: {len(nsfw_paths)}\")\n", 126 "\n", 127 " sfw_humans_path = Path(config.sfw_humans_dir)\n", 128 " if sfw_humans_path.exists():\n", 129 " sfw_paths += [str(f) for f in sfw_humans_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n", 130 " \n", 131 " sfw_anime_path = Path(config.sfw_anime_dir)\n", 132 " if sfw_anime_path.exists():\n", 133 " sfw_paths += [str(f) for f in sfw_anime_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n", 134 " print(f\"SFW images: {len(sfw_paths)}\")\n", 135 "\n", 136 " nsfw_count, sfw_count = len(nsfw_paths), len(sfw_paths)\n", 137 " # if nsfw_count > 0 and sfw_count > 0:\n", 138 " # ratio = nsfw_count / sfw_count\n", 139 " # target_ratio = 1.2\n", 140 " # if ratio > 2.0 or ratio < 0.5:\n", 141 " # if ratio > target_ratio:\n", 142 " # nsfw_paths = random.sample(nsfw_paths, int(sfw_count * target_ratio))\n", 143 " # else:\n", 144 " # sfw_paths = random.sample(sfw_paths, int(nsfw_count / target_ratio))\n", 145 "\n", 146 " image_paths = nsfw_paths + sfw_paths\n", 147 " labels = [1] * len(nsfw_paths) + [0] * len(sfw_paths) + [0]\n", 148 "\n", 149 " combined = list(zip(image_paths, labels))\n", 150 " random.shuffle(combined)\n", 151 " image_paths, labels = zip(*combined)\n", 152 " \n", 153 " print(f\"Total images: {len(image_paths)}\")\n", 154 " return list(image_paths), list(labels)" 155 ] 156 }, 157 { 158 "cell_type": "code", 159 "execution_count": 4, 160 "id": "21b22a1e-b33a-4dd7-802c-7bd65eb85b59", 161 "metadata": {}, 162 "outputs": [], 163 "source": [ 164 "def create_dataloaders(image_paths, labels, processor, config, augment=False):\n", 165 " total_size = len(image_paths)\n", 166 " train_size = int(config.train_ratio * total_size)\n", 167 " val_size = int(config.val_ratio * total_size)\n", 168 " test_size = total_size - train_size - val_size\n", 169 "\n", 170 " full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)\n", 171 " train_temp, val_temp, test_temp = random_split(\n", 172 " full_dataset, [train_size, val_size, test_size],\n", 173 " generator=torch.Generator().manual_seed(42)\n", 174 " )\n", 175 "\n", 176 " train_dataset = NSFWDataset(\n", 177 " [image_paths[i] for i in train_temp.indices],\n", 178 " [labels[i] for i in train_temp.indices],\n", 179 " processor, augment=augment\n", 180 " )\n", 181 " val_dataset = NSFWDataset(\n", 182 " [image_paths[i] for i in val_temp.indices],\n", 183 " [labels[i] for i in val_temp.indices],\n", 184 " processor, augment=False\n", 185 " )\n", 186 " test_dataset = NSFWDataset(\n", 187 " [image_paths[i] for i in test_temp.indices],\n", 188 " [labels[i] for i in test_temp.indices],\n", 189 " processor, augment=False\n", 190 " )\n", 191 "\n", 192 " pin_memory = config.device.type == \"cuda\"\n", 193 " \n", 194 " train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, \n", 195 " num_workers=config.num_workers, pin_memory=pin_memory)\n", 196 " val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,\n", 197 " num_workers=config.num_workers, pin_memory=pin_memory)\n", 198 " test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,\n", 199 " num_workers=config.num_workers, pin_memory=pin_memory)\n", 200 "\n", 201 " print(f\"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")\n", 202 " return train_loader, val_loader, test_loader" 203 ] 204 }, 205 { 206 "cell_type": "code", 207 "execution_count": 5, 208 "id": "43bef262-249c-471e-92b0-2782ea8b9241", 209 "metadata": {}, 210 "outputs": [], 211 "source": [ 212 "def train_epoch(model, train_loader, criterion, optimizer, device):\n", 213 " model.train()\n", 214 " running_loss, correct, total = 0.0, 0, 0\n", 215 "\n", 216 " for images, labels in tqdm(train_loader, desc=\"Training\"):\n", 217 " images, labels = images.to(device), labels.to(device)\n", 218 " \n", 219 " optimizer.zero_grad()\n", 220 " outputs = model(pixel_values=images)\n", 221 " loss = criterion(outputs.logits, labels)\n", 222 " loss.backward()\n", 223 " optimizer.step()\n", 224 "\n", 225 " running_loss += loss.item()\n", 226 " _, predicted = torch.max(outputs.logits, 1)\n", 227 " total += labels.size(0)\n", 228 " correct += (predicted == labels).sum().item()\n", 229 "\n", 230 " return running_loss / len(train_loader), 100.0 * correct / total\n", 231 "\n", 232 "def validate(model, val_loader, criterion, device):\n", 233 " model.eval()\n", 234 " running_loss, correct, total = 0.0, 0, 0\n", 235 " all_preds, all_labels = [], []\n", 236 "\n", 237 " with torch.no_grad():\n", 238 " for images, labels in tqdm(val_loader, desc=\"Validating\"):\n", 239 " images, labels = images.to(device), labels.to(device)\n", 240 " \n", 241 " outputs = model(pixel_values=images)\n", 242 " loss = criterion(outputs.logits, labels)\n", 243 "\n", 244 " running_loss += loss.item()\n", 245 " _, predicted = torch.max(outputs.logits, 1)\n", 246 " total += labels.size(0)\n", 247 " correct += (predicted == labels).sum().item()\n", 248 " \n", 249 " all_preds.extend(predicted.cpu().numpy())\n", 250 " all_labels.extend(labels.cpu().numpy())\n", 251 "\n", 252 " return running_loss / len(val_loader), 100.0 * correct / total, all_preds, all_labels" 253 ] 254 }, 255 { 256 "cell_type": "code", 257 "execution_count": 6, 258 "id": "d655026d-0a94-4f93-abe9-1dc4c0209c00", 259 "metadata": {}, 260 "outputs": [], 261 "source": [ 262 "def train_model(model, train_loader, val_loader, config):\n", 263 " criterion = nn.CrossEntropyLoss(label_smoothing=0.1)\n", 264 " optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)\n", 265 " scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=\"min\", factor=0.5, patience=2)\n", 266 "\n", 267 " best_val_acc = 0.0\n", 268 " best_model_state = None\n", 269 " history = {\"train_loss\": [], \"train_acc\": [], \"val_loss\": [], \"val_acc\": []}\n", 270 "\n", 271 " for epoch in range(config.num_epochs):\n", 272 " print(f\"\\nEpoch {epoch + 1}/{config.num_epochs}\")\n", 273 " \n", 274 " train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.device)\n", 275 " val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)\n", 276 " \n", 277 " scheduler.step(val_loss)\n", 278 " \n", 279 " history[\"train_loss\"].append(train_loss)\n", 280 " history[\"train_acc\"].append(train_acc)\n", 281 " history[\"val_loss\"].append(val_loss)\n", 282 " history[\"val_acc\"].append(val_acc)\n", 283 "\n", 284 " print(f\"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%\")\n", 285 " print(f\"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%\")\n", 286 "\n", 287 " if val_acc > best_val_acc:\n", 288 " best_val_acc = val_acc\n", 289 " best_model_state = model.state_dict().copy()\n", 290 " print(f\"✓ New best model!\")\n", 291 "\n", 292 " if best_model_state:\n", 293 " model.load_state_dict(best_model_state)\n", 294 " \n", 295 " return model, history" 296 ] 297 }, 298 { 299 "cell_type": "code", 300 "execution_count": 7, 301 "id": "2da4a196-6947-4f85-8a69-0d9a0d8dfd97", 302 "metadata": {}, 303 "outputs": [ 304 { 305 "name": "stdout", 306 "output_type": "stream", 307 "text": [ 308 "NSFW images: 13595\n", 309 "SFW images: 13017\n", 310 "Total images: 26612\n", 311 "Train: 18628, Val: 3991, Test: 3993\n", 312 "Model loaded on cuda\n" 313 ] 314 } 315 ], 316 "source": [ 317 "processor = ViTImageProcessor.from_pretrained(CONFIG.model_name)\n", 318 "image_paths, labels = load_image_paths_and_labels(CONFIG)\n", 319 "train_loader, val_loader, test_loader = create_dataloaders(\n", 320 " image_paths, labels, processor, CONFIG, augment=True\n", 321 ")\n", 322 "\n", 323 "model = ViTForImageClassification.from_pretrained(\n", 324 " CONFIG.model_name,\n", 325 " num_labels=CONFIG.num_classes,\n", 326 " ignore_mismatched_sizes=True,\n", 327 ").to(CONFIG.device)\n", 328 "\n", 329 "print(f\"Model loaded on {CONFIG.device}\")" 330 ] 331 }, 332 { 333 "cell_type": "code", 334 "execution_count": null, 335 "id": "d37bd659-fb18-4445-83b0-1044d9239ea7", 336 "metadata": {}, 337 "outputs": [ 338 { 339 "name": "stdout", 340 "output_type": "stream", 341 "text": [ 342 "\n", 343 "Epoch 1/6\n" 344 ] 345 }, 346 { 347 "data": { 348 "application/vnd.jupyter.widget-view+json": { 349 "model_id": "b71483a651d6426cabc38ecd118d0334", 350 "version_major": 2, 351 "version_minor": 0 352 }, 353 "text/plain": [ 354 "Training: 0%| | 0/1165 [00:00<?, ?it/s]" 355 ] 356 }, 357 "metadata": {}, 358 "output_type": "display_data" 359 }, 360 { 361 "data": { 362 "application/vnd.jupyter.widget-view+json": { 363 "model_id": "e57878456bd444199c712aea7d679932", 364 "version_major": 2, 365 "version_minor": 0 366 }, 367 "text/plain": [ 368 "Validating: 0%| | 0/250 [00:00<?, ?it/s]" 369 ] 370 }, 371 "metadata": {}, 372 "output_type": "display_data" 373 }, 374 { 375 "name": "stdout", 376 "output_type": "stream", 377 "text": [ 378 "Train Loss: 0.3366 | Train Acc: 94.98%\n", 379 "Val Loss: 0.2687 | Val Acc: 96.37%\n", 380 "✓ New best model!\n", 381 "\n", 382 "Epoch 2/6\n" 383 ] 384 }, 385 { 386 "data": { 387 "application/vnd.jupyter.widget-view+json": { 388 "model_id": "1a73bac0a3d247acb1fd19fe4ee65b48", 389 "version_major": 2, 390 "version_minor": 0 391 }, 392 "text/plain": [ 393 "Training: 0%| | 0/1165 [00:00<?, ?it/s]" 394 ] 395 }, 396 "metadata": {}, 397 "output_type": "display_data" 398 } 399 ], 400 "source": [ 401 "model, history = train_model(model, train_loader, val_loader, CONFIG)" 402 ] 403 }, 404 { 405 "cell_type": "code", 406 "execution_count": null, 407 "id": "61ad9fdf-38a3-44c5-8ab9-56ef9767e170", 408 "metadata": {}, 409 "outputs": [], 410 "source": [ 411 "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n", 412 "\n", 413 "ax1.plot(history[\"train_loss\"], label=\"Train\")\n", 414 "ax1.plot(history[\"val_loss\"], label=\"Val\")\n", 415 "ax1.set_xlabel(\"Epoch\")\n", 416 "ax1.set_ylabel(\"Loss\")\n", 417 "ax1.legend()\n", 418 "ax1.set_title(\"Loss\")\n", 419 "\n", 420 "ax2.plot(history[\"train_acc\"], label=\"Train\")\n", 421 "ax2.plot(history[\"val_acc\"], label=\"Val\")\n", 422 "ax2.set_xlabel(\"Epoch\")\n", 423 "ax2.set_ylabel(\"Accuracy (%)\")\n", 424 "ax2.legend()\n", 425 "ax2.set_title(\"Accuracy\")\n", 426 "\n", 427 "plt.tight_layout()\n", 428 "plt.show()" 429 ] 430 }, 431 { 432 "cell_type": "code", 433 "execution_count": null, 434 "id": "d4a46d90-df37-4c11-91fd-1ef107d5b231", 435 "metadata": {}, 436 "outputs": [], 437 "source": [ 438 "def show_confusion_matrix(model, test_loader, config):\n", 439 " model.eval()\n", 440 " all_preds, all_labels = [], []\n", 441 " \n", 442 " with torch.no_grad():\n", 443 " for images, labels in tqdm(test_loader, desc=\"Evaluating\"):\n", 444 " images = images.to(config.device)\n", 445 " outputs = model(pixel_values=images)\n", 446 " _, predicted = torch.max(outputs.logits, 1)\n", 447 " \n", 448 " all_preds.extend(predicted.cpu().numpy())\n", 449 " all_labels.extend(labels.numpy())\n", 450 " \n", 451 " all_preds = np.array(all_preds)\n", 452 " all_labels = np.array(all_labels)\n", 453 " \n", 454 " print(\"Classification Report:\")\n", 455 " print(classification_report(all_labels, all_preds, target_names=config.class_names, digits=4))\n", 456 " \n", 457 " cm = confusion_matrix(all_labels, all_preds)\n", 458 " plt.figure(figsize=(8, 6))\n", 459 " sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n", 460 " xticklabels=config.class_names, yticklabels=config.class_names)\n", 461 " plt.title(\"Confusion Matrix\")\n", 462 " plt.ylabel(\"True Label\")\n", 463 " plt.xlabel(\"Predicted Label\")\n", 464 " plt.tight_layout()\n", 465 " plt.show()\n", 466 " \n", 467 " accuracy = (all_preds == all_labels).sum() / len(all_labels)\n", 468 " print(f\"\\nAccuracy: {accuracy:.2%}\")\n", 469 " \n", 470 " return all_preds, all_labels\n", 471 "\n", 472 "all_preds, all_labels = show_confusion_matrix(model, test_loader, CONFIG)" 473 ] 474 }, 475 { 476 "cell_type": "code", 477 "execution_count": null, 478 "id": "93836555-7fbe-4cbb-b7ac-91d3e06b1da6", 479 "metadata": {}, 480 "outputs": [], 481 "source": [ 482 "def review_misclassified(model, test_loader, config, top_n=10):\n", 483 " model.eval()\n", 484 " \n", 485 " false_positives = []\n", 486 " false_negatives = []\n", 487 " \n", 488 " dataset = test_loader.dataset\n", 489 " \n", 490 " with torch.no_grad():\n", 491 " sample_idx = 0\n", 492 " for images, labels in tqdm(test_loader, desc=\"Finding misclassified\"):\n", 493 " images = images.to(config.device)\n", 494 " outputs = model(pixel_values=images)\n", 495 " probs = torch.softmax(outputs.logits, dim=1)\n", 496 " _, predicted = torch.max(outputs.logits, 1)\n", 497 " \n", 498 " for i in range(len(labels)):\n", 499 " true_label = labels[i].item()\n", 500 " pred_label = predicted[i].item()\n", 501 " confidence = probs[i, pred_label].item()\n", 502 " \n", 503 " if true_label != pred_label:\n", 504 " img_path = dataset.image_paths[sample_idx]\n", 505 " \n", 506 " if true_label == 0 and pred_label == 1:\n", 507 " false_positives.append({\n", 508 " \"path\": img_path,\n", 509 " \"confidence\": confidence,\n", 510 " \"nsfw_prob\": probs[i, 1].item(),\n", 511 " })\n", 512 " elif true_label == 1 and pred_label == 0:\n", 513 " false_negatives.append({\n", 514 " \"path\": img_path,\n", 515 " \"confidence\": confidence,\n", 516 " \"sfw_prob\": probs[i, 0].item(),\n", 517 " })\n", 518 " \n", 519 " sample_idx += 1\n", 520 " \n", 521 " false_positives = sorted(false_positives, key=lambda x: x[\"confidence\"], reverse=True)\n", 522 " false_negatives = sorted(false_negatives, key=lambda x: x[\"confidence\"], reverse=True)\n", 523 " \n", 524 " if false_positives:\n", 525 " fig, axes = plt.subplots(2, 5, figsize=(15, 6))\n", 526 " axes = axes.flatten()\n", 527 " \n", 528 " for i, fp in enumerate(false_positives[:top_n]):\n", 529 " if i < len(axes):\n", 530 " img = Image.open(fp[\"path\"]).convert(\"RGB\")\n", 531 " axes[i].imshow(img)\n", 532 " axes[i].set_title(f\"NSFW prob: {fp['nsfw_prob']:.2%}\", fontsize=9)\n", 533 " axes[i].axis(\"off\")\n", 534 " \n", 535 " for i in range(len(false_positives[:top_n]), len(axes)):\n", 536 " axes[i].axis(\"off\")\n", 537 " \n", 538 " plt.suptitle(\"False Positives (SFW → NSFW)\", fontsize=14)\n", 539 " plt.tight_layout()\n", 540 " plt.show()\n", 541 " else:\n", 542 " print(\"No false positives found!\")\n", 543 " \n", 544 " if false_negatives:\n", 545 " fig, axes = plt.subplots(2, 5, figsize=(15, 6))\n", 546 " axes = axes.flatten()\n", 547 " \n", 548 " for i, fn in enumerate(false_negatives[:top_n]):\n", 549 " if i < len(axes):\n", 550 " img = Image.open(fn[\"path\"]).convert(\"RGB\")\n", 551 " axes[i].imshow(img)\n", 552 " axes[i].set_title(f\"SFW prob: {fn['sfw_prob']:.2%}\", fontsize=9)\n", 553 " axes[i].axis(\"off\")\n", 554 " \n", 555 " for i in range(len(false_negatives[:top_n]), len(axes)):\n", 556 " axes[i].axis(\"off\")\n", 557 " \n", 558 " plt.suptitle(\"False Negatives (NSFW → SFW)\", fontsize=14)\n", 559 " plt.tight_layout()\n", 560 " plt.show()\n", 561 " else:\n", 562 " print(\"No false negatives found!\")\n", 563 " \n", 564 " print(f\"\\nTotal: {len(false_positives)} false positives, {len(false_negatives)} false negatives\")\n", 565 " \n", 566 " return false_positives, false_negatives\n", 567 "\n", 568 "false_positives, false_negatives = review_misclassified(model, test_loader, CONFIG, top_n=10)" 569 ] 570 }, 571 { 572 "cell_type": "code", 573 "execution_count": null, 574 "id": "69e5ce7b-fe32-4d01-ad5e-d1055f06b4ef", 575 "metadata": {}, 576 "outputs": [], 577 "source": [ 578 "print(\"FALSE POSITIVE PATHS:\")\n", 579 "for fp in false_positives[:10]:\n", 580 " print(f\" {fp['path']}\")\n", 581 "\n", 582 "print(\"\\nFALSE NEGATIVE PATHS:\")\n", 583 "for fn in false_negatives[:10]:\n", 584 " print(f\" {fn['path']}\")" 585 ] 586 }, 587 { 588 "cell_type": "code", 589 "execution_count": null, 590 "id": "c06bcf4f-5acb-4827-8991-e85b9b7b7fbd", 591 "metadata": {}, 592 "outputs": [], 593 "source": [ 594 "def save_model(model, processor, path=\"./nsfw_classifier\"):\n", 595 " Path(path).mkdir(parents=True, exist_ok=True)\n", 596 " model.save_pretrained(path)\n", 597 " processor.save_pretrained(path)\n", 598 " print(f\"Model saved to {path}\")\n", 599 "\n", 600 "# save_model(model, processor, CONFIG.model_save_path)" 601 ] 602 }, 603 { 604 "cell_type": "code", 605 "execution_count": null, 606 "id": "50fbe061-0435-4b0b-ab66-0bb773ba6ab0", 607 "metadata": {}, 608 "outputs": [], 609 "source": [ 610 "def load_model(path=\"./nsfw_classifier\"):\n", 611 " model = ViTForImageClassification.from_pretrained(path).to(CONFIG.device)\n", 612 " processor = ViTImageProcessor.from_pretrained(path)\n", 613 " model.eval()\n", 614 " print(f\"Model loaded from {path}\")\n", 615 " return model, processor\n", 616 "\n", 617 "# model, processor = load_model(CONFIG.model_save_path)" 618 ] 619 }, 620 { 621 "cell_type": "code", 622 "execution_count": null, 623 "id": "50fe4bec-c29e-453b-beca-16304c1c7a1a", 624 "metadata": {}, 625 "outputs": [], 626 "source": [ 627 "import requests\n", 628 "from io import BytesIO\n", 629 "\n", 630 "def predict_url(url):\n", 631 " model.eval()\n", 632 "\n", 633 " response = requests.get(url)\n", 634 " response.raise_for_status()\n", 635 " image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n", 636 " \n", 637 " inputs = processor(images=image, return_tensors=\"pt\")\n", 638 " with torch.no_grad():\n", 639 " outputs = model(pixel_values=inputs[\"pixel_values\"].to(CONFIG.device))\n", 640 " probs = torch.softmax(outputs.logits, dim=1)[0]\n", 641 " \n", 642 " plt.figure(figsize=(6, 6))\n", 643 " plt.imshow(image)\n", 644 " plt.title(f\"SFW: {probs[0]:.2%} | NSFW: {probs[1]:.2%}\")\n", 645 " plt.axis(\"off\")\n", 646 " plt.show()\n", 647 " \n", 648 " return {\"sfw\": probs[0].item(), \"nsfw\": probs[1].item()}\n", 649 "\n", 650 "url_to_predict = \"\"\"\n", 651 "https://cdn.bsky.app/img/feed_thumbnail/plain/did:plc:j55kvyxp44daiz4yvx4g4ari/bafkreiebtixeda6jgde6gtsbffza7yxozijbisewev7i5ykbyw7m5rsboi@jpeg\n", 652 "\"\"\"\n", 653 "\n", 654 "predict_url(url_to_predict.strip())" 655 ] 656 } 657 ], 658 "metadata": { 659 "kernelspec": { 660 "display_name": "Python 3 (ipykernel)", 661 "language": "python", 662 "name": "python3" 663 }, 664 "language_info": { 665 "codemirror_mode": { 666 "name": "ipython", 667 "version": 3 668 }, 669 "file_extension": ".py", 670 "mimetype": "text/x-python", 671 "name": "python", 672 "nbconvert_exporter": "python", 673 "pygments_lexer": "ipython3", 674 "version": "3.12.7" 675 } 676 }, 677 "nbformat": 4, 678 "nbformat_minor": 5 679}