{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7ab47b82-29e6-473b-8c13-af2390900d1c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "import logging\n",
    "import random\n",
    "from io import BytesIO\n",
    "from pathlib import Path\n",
22 "from typing import Any, Dict, List, Optional, Set, Tuple, Union\n",
23 "\n",
24 "import matplotlib.pyplot as plt\n",
25 "import numpy as np\n",
26 "import seaborn as sns\n",
27 "import torch\n",
28 "import torch.nn as nn\n",
29 "import torch.optim as optim\n",
30 "from PIL import Image\n",
31 "from sklearn.metrics import classification_report, confusion_matrix\n",
32 "from torch.utils.data import DataLoader, Dataset, random_split\n",
33 "from torchvision import transforms\n",
34 "from tqdm.notebook import tqdm\n",
35 "from transformers import ViTForImageClassification, ViTImageProcessor\n",
36 "\n",
37 "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n",
38 "logger = logging.getLogger()\n",
39 "\n",
40 "class Config:\n",
41 " def __init__(self):\n",
42 " self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
43 " self.model_name = \"Falconsai/nsfw_image_detection\"\n",
44 " self.num_classes = 2\n",
45 " self.batch_size = 16\n",
46 " self.num_epochs = 6 # 10\n",
47 " self.learning_rate = 5e-6 # 2e-5\n",
48 " self.weight_decay = 0.01\n",
49 " self.train_ratio = 0.7\n",
50 " self.val_ratio = 0.15\n",
51 " self.num_workers = 4\n",
52 " self.model_save_path = \"./models/nsfw_vit_classifier\"\n",
53 " self.nsfw_dir = \"./newdata/nsfw\"\n",
54 " self.sfw_humans_dir = \"./newdata/sfw-human\"\n",
55 " self.sfw_anime_dir = \"./newdata/sfw-anime\"\n",
56 " self.class_names = [\"SFW\", \"NSFW\"]\n",
57 "\n",
58 "CONFIG = Config()\n",
59 "print(f\"Using device: {CONFIG.device}\")"
60 ]
61 },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1ae26bf1-3e5e-4542-ab89-e9e5de6758fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "class NSFWDataset(Dataset):\n",
    "    def __init__(\n",
    "        self,\n",
    "        image_paths: List[str],\n",
    "        labels: List[int],\n",
    "        processor: ViTImageProcessor,\n",
    "        augment: bool = False,\n",
    "    ):\n",
    "        self.image_paths = image_paths\n",
    "        self.labels = labels\n",
    "        self.processor = processor\n",
    "\n",
81 " if augment:\n",
82 " self.transform = transforms.Compose([\n",
83 " transforms.RandomHorizontalFlip(p=0.5),\n",
84 " transforms.RandomRotation(15),\n",
85 " transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),\n",
86 " transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),\n",
87 " ])\n",
88 " else:\n",
89 " self.transform = transforms.Compose([\n",
90 " transforms.Resize((224, 224)),\n",
91 " ])\n",
92 "\n",
93 " def __len__(self) -> int:\n",
94 " return len(self.image_paths)\n",
95 "\n",
96 " def __getitem__(self, idx: int):\n",
97 " img_path = self.image_paths[idx]\n",
98 " try:\n",
99 " image = Image.open(img_path).convert(\"RGB\")\n",
100 " image = self.transform(image)\n",
101 " processed = self.processor(images=image, return_tensors=\"pt\")\n",
102 " pixel_values = processed[\"pixel_values\"].squeeze(0)\n",
103 " return pixel_values, self.labels[idx]\n",
104 " except Exception as e:\n",
105 " logger.error(f\"Error loading image {img_path}: {e}\")\n",
106 " return None"
107 ]
108 },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7fdad2e1-0976-46db-be9e-675ae59d09e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image_paths_and_labels(config: Config) -> Tuple[List[str], List[int]]:\n",
    "    nsfw_paths = []\n",
    "    sfw_paths = []\n",
    "    valid_extensions = {\".jpg\", \".jpeg\", \".png\", \".webp\"}\n",
    "\n",
122 " nsfw_path = Path(config.nsfw_dir)\n",
123 " if nsfw_path.exists():\n",
124 " nsfw_paths = [str(f) for f in nsfw_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
125 " print(f\"NSFW images: {len(nsfw_paths)}\")\n",
126 "\n",
127 " sfw_humans_path = Path(config.sfw_humans_dir)\n",
128 " if sfw_humans_path.exists():\n",
129 " sfw_paths += [str(f) for f in sfw_humans_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
130 " \n",
131 " sfw_anime_path = Path(config.sfw_anime_dir)\n",
132 " if sfw_anime_path.exists():\n",
133 " sfw_paths += [str(f) for f in sfw_anime_path.iterdir() if f.suffix.lower() in valid_extensions or f.name.startswith(\"did:\")]\n",
134 " print(f\"SFW images: {len(sfw_paths)}\")\n",
135 "\n",
136 " nsfw_count, sfw_count = len(nsfw_paths), len(sfw_paths)\n",
137 " # if nsfw_count > 0 and sfw_count > 0:\n",
138 " # ratio = nsfw_count / sfw_count\n",
139 " # target_ratio = 1.2\n",
140 " # if ratio > 2.0 or ratio < 0.5:\n",
141 " # if ratio > target_ratio:\n",
142 " # nsfw_paths = random.sample(nsfw_paths, int(sfw_count * target_ratio))\n",
143 " # else:\n",
144 " # sfw_paths = random.sample(sfw_paths, int(nsfw_count / target_ratio))\n",
145 "\n",
146 " image_paths = nsfw_paths + sfw_paths\n",
147 " labels = [1] * len(nsfw_paths) + [0] * len(sfw_paths) + [0]\n",
148 "\n",
149 " combined = list(zip(image_paths, labels))\n",
150 " random.shuffle(combined)\n",
151 " image_paths, labels = zip(*combined)\n",
152 " \n",
153 " print(f\"Total images: {len(image_paths)}\")\n",
154 " return list(image_paths), list(labels)"
155 ]
156 },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "21b22a1e-b33a-4dd7-802c-7bd65eb85b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_dataloaders(image_paths, labels, processor, config, augment=False):\n",
    "    total_size = len(image_paths)\n",
    "    train_size = int(config.train_ratio * total_size)\n",
    "    val_size = int(config.val_ratio * total_size)\n",
    "    test_size = total_size - train_size - val_size\n",
    "\n",
170 " full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)\n",
171 " train_temp, val_temp, test_temp = random_split(\n",
172 " full_dataset, [train_size, val_size, test_size],\n",
173 " generator=torch.Generator().manual_seed(42)\n",
174 " )\n",
175 "\n",
176 " train_dataset = NSFWDataset(\n",
177 " [image_paths[i] for i in train_temp.indices],\n",
178 " [labels[i] for i in train_temp.indices],\n",
179 " processor, augment=augment\n",
180 " )\n",
181 " val_dataset = NSFWDataset(\n",
182 " [image_paths[i] for i in val_temp.indices],\n",
183 " [labels[i] for i in val_temp.indices],\n",
184 " processor, augment=False\n",
185 " )\n",
186 " test_dataset = NSFWDataset(\n",
187 " [image_paths[i] for i in test_temp.indices],\n",
188 " [labels[i] for i in test_temp.indices],\n",
189 " processor, augment=False\n",
190 " )\n",
191 "\n",
192 " pin_memory = config.device.type == \"cuda\"\n",
193 " \n",
194 " train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, \n",
195 " num_workers=config.num_workers, pin_memory=pin_memory)\n",
196 " val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False,\n",
197 " num_workers=config.num_workers, pin_memory=pin_memory)\n",
198 " test_loader = DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False,\n",
199 " num_workers=config.num_workers, pin_memory=pin_memory)\n",
200 "\n",
201 " print(f\"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}\")\n",
202 " return train_loader, val_loader, test_loader"
203 ]
204 },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "43bef262-249c-471e-92b0-2782ea8b9241",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, train_loader, criterion, optimizer, device):\n",
    "    model.train()\n",
    "    running_loss, correct, total = 0.0, 0, 0\n",
    "\n",
216 " for images, labels in tqdm(train_loader, desc=\"Training\"):\n",
217 " images, labels = images.to(device), labels.to(device)\n",
218 " \n",
219 " optimizer.zero_grad()\n",
220 " outputs = model(pixel_values=images)\n",
221 " loss = criterion(outputs.logits, labels)\n",
222 " loss.backward()\n",
223 " optimizer.step()\n",
224 "\n",
225 " running_loss += loss.item()\n",
226 " _, predicted = torch.max(outputs.logits, 1)\n",
227 " total += labels.size(0)\n",
228 " correct += (predicted == labels).sum().item()\n",
229 "\n",
230 " return running_loss / len(train_loader), 100.0 * correct / total\n",
231 "\n",
232 "def validate(model, val_loader, criterion, device):\n",
233 " model.eval()\n",
234 " running_loss, correct, total = 0.0, 0, 0\n",
235 " all_preds, all_labels = [], []\n",
236 "\n",
237 " with torch.no_grad():\n",
238 " for images, labels in tqdm(val_loader, desc=\"Validating\"):\n",
239 " images, labels = images.to(device), labels.to(device)\n",
240 " \n",
241 " outputs = model(pixel_values=images)\n",
242 " loss = criterion(outputs.logits, labels)\n",
243 "\n",
244 " running_loss += loss.item()\n",
245 " _, predicted = torch.max(outputs.logits, 1)\n",
246 " total += labels.size(0)\n",
247 " correct += (predicted == labels).sum().item()\n",
248 " \n",
249 " all_preds.extend(predicted.cpu().numpy())\n",
250 " all_labels.extend(labels.cpu().numpy())\n",
251 "\n",
252 " return running_loss / len(val_loader), 100.0 * correct / total, all_preds, all_labels"
253 ]
254 },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d655026d-0a94-4f93-abe9-1dc4c0209c00",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(model, train_loader, val_loader, config):\n",
    "    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
    "    optimizer = optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay)\n",
    "    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=\"min\", factor=0.5, patience=2)\n",
266 "\n",
267 " best_val_acc = 0.0\n",
268 " best_model_state = None\n",
269 " history = {\"train_loss\": [], \"train_acc\": [], \"val_loss\": [], \"val_acc\": []}\n",
270 "\n",
271 " for epoch in range(config.num_epochs):\n",
272 " print(f\"\\nEpoch {epoch + 1}/{config.num_epochs}\")\n",
273 " \n",
274 " train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, config.device)\n",
275 " val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)\n",
276 " \n",
277 " scheduler.step(val_loss)\n",
278 " \n",
279 " history[\"train_loss\"].append(train_loss)\n",
280 " history[\"train_acc\"].append(train_acc)\n",
281 " history[\"val_loss\"].append(val_loss)\n",
282 " history[\"val_acc\"].append(val_acc)\n",
283 "\n",
284 " print(f\"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%\")\n",
285 " print(f\"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%\")\n",
286 "\n",
287 " if val_acc > best_val_acc:\n",
288 " best_val_acc = val_acc\n",
289 " best_model_state = model.state_dict().copy()\n",
290 " print(f\"✓ New best model!\")\n",
291 "\n",
292 " if best_model_state:\n",
293 " model.load_state_dict(best_model_state)\n",
294 " \n",
295 " return model, history"
296 ]
297 },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2da4a196-6947-4f85-8a69-0d9a0d8dfd97",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "NSFW images: 13595\n",
      "SFW images: 13017\n",
      "Total images: 26612\n",
      "Train: 18628, Val: 3991, Test: 3993\n",
      "Model loaded on cuda\n"
     ]
    }
   ],
   "source": [
    "processor = ViTImageProcessor.from_pretrained(CONFIG.model_name)\n",
    "image_paths, labels = load_image_paths_and_labels(CONFIG)\n",
    "train_loader, val_loader, test_loader = create_dataloaders(\n",
    "    image_paths, labels, processor, CONFIG, augment=True\n",
    ")\n",
    "\n",
323 "model = ViTForImageClassification.from_pretrained(\n",
324 " CONFIG.model_name,\n",
325 " num_labels=CONFIG.num_classes,\n",
326 " ignore_mismatched_sizes=True,\n",
327 ").to(CONFIG.device)\n",
328 "\n",
329 "print(f\"Model loaded on {CONFIG.device}\")"
330 ]
331 },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d37bd659-fb18-4445-83b0-1044d9239ea7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Epoch 1/6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b71483a651d6426cabc38ecd118d0334",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/1165 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e57878456bd444199c712aea7d679932",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validating:   0%|          | 0/250 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train Loss: 0.3366 | Train Acc: 94.98%\n",
      "Val Loss: 0.2687 | Val Acc: 96.37%\n",
      "✓ New best model!\n",
      "\n",
      "Epoch 2/6\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a73bac0a3d247acb1fd19fe4ee65b48",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training:   0%|          | 0/1165 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "model, history = train_model(model, train_loader, val_loader, CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61ad9fdf-38a3-44c5-8ab9-56ef9767e170",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "ax1.plot(history[\"train_loss\"], label=\"Train\")\n",
    "ax1.plot(history[\"val_loss\"], label=\"Val\")\n",
    "ax1.set_xlabel(\"Epoch\")\n",
    "ax1.set_ylabel(\"Loss\")\n",
    "ax1.legend()\n",
    "ax1.set_title(\"Loss\")\n",
    "\n",
    "ax2.plot(history[\"train_acc\"], label=\"Train\")\n",
    "ax2.plot(history[\"val_acc\"], label=\"Val\")\n",
    "ax2.set_xlabel(\"Epoch\")\n",
    "ax2.set_ylabel(\"Accuracy (%)\")\n",
    "ax2.legend()\n",
    "ax2.set_title(\"Accuracy\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4a46d90-df37-4c11-91fd-1ef107d5b231",
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_confusion_matrix(model, test_loader, config):\n",
    "    model.eval()\n",
    "    all_preds, all_labels = [], []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels in tqdm(test_loader, desc=\"Evaluating\"):\n",
    "            images = images.to(config.device)\n",
    "            outputs = model(pixel_values=images)\n",
    "            _, predicted = torch.max(outputs.logits, 1)\n",
    "\n",
    "            all_preds.extend(predicted.cpu().numpy())\n",
    "            all_labels.extend(labels.numpy())\n",
    "\n",
    "    all_preds = np.array(all_preds)\n",
    "    all_labels = np.array(all_labels)\n",
    "\n",
    "    print(\"Classification Report:\")\n",
    "    print(classification_report(all_labels, all_preds, target_names=config.class_names, digits=4))\n",
    "\n",
    "    cm = confusion_matrix(all_labels, all_preds)\n",
    "    plt.figure(figsize=(8, 6))\n",
    "    sns.heatmap(cm, annot=True, fmt=\"d\", cmap=\"Blues\",\n",
    "                xticklabels=config.class_names, yticklabels=config.class_names)\n",
    "    plt.title(\"Confusion Matrix\")\n",
    "    plt.ylabel(\"True Label\")\n",
    "    plt.xlabel(\"Predicted Label\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    accuracy = (all_preds == all_labels).sum() / len(all_labels)\n",
    "    print(f\"\\nAccuracy: {accuracy:.2%}\")\n",
    "\n",
    "    return all_preds, all_labels\n",
    "\n",
    "all_preds, all_labels = show_confusion_matrix(model, test_loader, CONFIG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93836555-7fbe-4cbb-b7ac-91d3e06b1da6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def review_misclassified(model, test_loader, config, top_n=10):\n",
    "    model.eval()\n",
    "\n",
    "    false_positives = []\n",
    "    false_negatives = []\n",
    "\n",
    "    dataset = test_loader.dataset\n",
    "\n",
    "    with torch.no_grad():\n",
    "        sample_idx = 0\n",
492 " for images, labels in tqdm(test_loader, desc=\"Finding misclassified\"):\n",
493 " images = images.to(config.device)\n",
494 " outputs = model(pixel_values=images)\n",
495 " probs = torch.softmax(outputs.logits, dim=1)\n",
496 " _, predicted = torch.max(outputs.logits, 1)\n",
497 " \n",
498 " for i in range(len(labels)):\n",
499 " true_label = labels[i].item()\n",
500 " pred_label = predicted[i].item()\n",
501 " confidence = probs[i, pred_label].item()\n",
502 " \n",
503 " if true_label != pred_label:\n",
504 " img_path = dataset.image_paths[sample_idx]\n",
505 " \n",
506 " if true_label == 0 and pred_label == 1:\n",
507 " false_positives.append({\n",
508 " \"path\": img_path,\n",
509 " \"confidence\": confidence,\n",
510 " \"nsfw_prob\": probs[i, 1].item(),\n",
511 " })\n",
512 " elif true_label == 1 and pred_label == 0:\n",
513 " false_negatives.append({\n",
514 " \"path\": img_path,\n",
515 " \"confidence\": confidence,\n",
516 " \"sfw_prob\": probs[i, 0].item(),\n",
517 " })\n",
518 " \n",
519 " sample_idx += 1\n",
520 " \n",
521 " false_positives = sorted(false_positives, key=lambda x: x[\"confidence\"], reverse=True)\n",
522 " false_negatives = sorted(false_negatives, key=lambda x: x[\"confidence\"], reverse=True)\n",
523 " \n",
524 " if false_positives:\n",
525 " fig, axes = plt.subplots(2, 5, figsize=(15, 6))\n",
526 " axes = axes.flatten()\n",
527 " \n",
528 " for i, fp in enumerate(false_positives[:top_n]):\n",
529 " if i < len(axes):\n",
530 " img = Image.open(fp[\"path\"]).convert(\"RGB\")\n",
531 " axes[i].imshow(img)\n",
532 " axes[i].set_title(f\"NSFW prob: {fp['nsfw_prob']:.2%}\", fontsize=9)\n",
533 " axes[i].axis(\"off\")\n",
534 " \n",
535 " for i in range(len(false_positives[:top_n]), len(axes)):\n",
536 " axes[i].axis(\"off\")\n",
537 " \n",
538 " plt.suptitle(\"False Positives (SFW → NSFW)\", fontsize=14)\n",
539 " plt.tight_layout()\n",
540 " plt.show()\n",
541 " else:\n",
542 " print(\"No false positives found!\")\n",
543 " \n",
544 " if false_negatives:\n",
545 " fig, axes = plt.subplots(2, 5, figsize=(15, 6))\n",
546 " axes = axes.flatten()\n",
547 " \n",
548 " for i, fn in enumerate(false_negatives[:top_n]):\n",
549 " if i < len(axes):\n",
550 " img = Image.open(fn[\"path\"]).convert(\"RGB\")\n",
551 " axes[i].imshow(img)\n",
552 " axes[i].set_title(f\"SFW prob: {fn['sfw_prob']:.2%}\", fontsize=9)\n",
553 " axes[i].axis(\"off\")\n",
554 " \n",
555 " for i in range(len(false_negatives[:top_n]), len(axes)):\n",
556 " axes[i].axis(\"off\")\n",
557 " \n",
558 " plt.suptitle(\"False Negatives (NSFW → SFW)\", fontsize=14)\n",
559 " plt.tight_layout()\n",
560 " plt.show()\n",
561 " else:\n",
562 " print(\"No false negatives found!\")\n",
563 " \n",
564 " print(f\"\\nTotal: {len(false_positives)} false positives, {len(false_negatives)} false negatives\")\n",
565 " \n",
566 " return false_positives, false_negatives\n",
567 "\n",
568 "false_positives, false_negatives = review_misclassified(model, test_loader, CONFIG, top_n=10)"
569 ]
570 },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e5ce7b-fe32-4d01-ad5e-d1055f06b4ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"FALSE POSITIVE PATHS:\")\n",
    "for fp in false_positives[:10]:\n",
    "    print(f\"  {fp['path']}\")\n",
    "\n",
    "print(\"\\nFALSE NEGATIVE PATHS:\")\n",
    "for fn in false_negatives[:10]:\n",
    "    print(f\"  {fn['path']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c06bcf4f-5acb-4827-8991-e85b9b7b7fbd",
   "metadata": {},
   "outputs": [],
   "source": [
    "def save_model(model, processor, path=\"./nsfw_classifier\"):\n",
    "    Path(path).mkdir(parents=True, exist_ok=True)\n",
    "    model.save_pretrained(path)\n",
    "    processor.save_pretrained(path)\n",
    "    print(f\"Model saved to {path}\")\n",
    "\n",
    "# save_model(model, processor, CONFIG.model_save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fbe061-0435-4b0b-ab66-0bb773ba6ab0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(path=\"./nsfw_classifier\"):\n",
    "    model = ViTForImageClassification.from_pretrained(path).to(CONFIG.device)\n",
    "    processor = ViTImageProcessor.from_pretrained(path)\n",
    "    model.eval()\n",
    "    print(f\"Model loaded from {path}\")\n",
    "    return model, processor\n",
    "\n",
    "# model, processor = load_model(CONFIG.model_save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50fe4bec-c29e-453b-beca-16304c1c7a1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "\n",
    "def predict_url(url):\n",
    "    model.eval()\n",
    "\n",
    "    response = requests.get(url, timeout=30)\n",
    "    response.raise_for_status()\n",
    "    image = Image.open(BytesIO(response.content)).convert(\"RGB\")\n",
    "\n",
637 " inputs = processor(images=image, return_tensors=\"pt\")\n",
638 " with torch.no_grad():\n",
639 " outputs = model(pixel_values=inputs[\"pixel_values\"].to(CONFIG.device))\n",
640 " probs = torch.softmax(outputs.logits, dim=1)[0]\n",
641 " \n",
642 " plt.figure(figsize=(6, 6))\n",
643 " plt.imshow(image)\n",
644 " plt.title(f\"SFW: {probs[0]:.2%} | NSFW: {probs[1]:.2%}\")\n",
645 " plt.axis(\"off\")\n",
646 " plt.show()\n",
647 " \n",
648 " return {\"sfw\": probs[0].item(), \"nsfw\": probs[1].item()}\n",
649 "\n",
650 "url_to_predict = \"\"\"\n",
651 "https://cdn.bsky.app/img/feed_thumbnail/plain/did:plc:j55kvyxp44daiz4yvx4g4ari/bafkreiebtixeda6jgde6gtsbffza7yxozijbisewev7i5ykbyw7m5rsboi@jpeg\n",
652 "\"\"\"\n",
653 "\n",
654 "predict_url(url_to_predict.strip())"
655 ]
656 }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}