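"""NSFW image classifier: fine-tunes a ViT model and exposes a small click CLI.

Commands (click exposes ``predict_url`` as ``predict-url``):

    train          fine-tune the model; ``--augment`` enables data augmentation
    predict        classify a local image file
    predict-url    download an image from a URL and classify it

Directory layout, hyperparameters, and device all come from ``config.CONFIG``;
see config.py.
"""
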
import logging
import random
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import click
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

from config import CONFIG, Config

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class NSFWDataset(Dataset[Tuple[torch.Tensor, int]]):
    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        processor: ViTImageProcessor,
        augment: bool = False,
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor

        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2, saturation=0.2
                    ),
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                ]
            )

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        img_path = self.image_paths[idx]

        try:
            image: Image.Image = Image.open(img_path).convert("RGB")
            image = self.transform(image)

            processed = self.processor(images=image, return_tensors="pt")

            pixel_values = processed["pixel_values"].squeeze(0)

            label: int = self.labels[idx]

            return pixel_values, label
        except Exception as e:
            # Log and re-raise: silently returning None here would only crash
            # later in the default collate function with a confusing error.
            logger.error(f"Error loading image {img_path}: {e}")
            raise


def load_image_paths_and_labels(
    nsfw_dir: str,
    sfw_human_dir: str,
    sfw_anime_dir: str,
) -> Tuple[List[str], List[int]]:
    nsfw_paths: List[str] = []
    sfw_paths: List[str] = []

    # Known NSFW false positives; always kept and labeled SFW below.
    override_paths: List[str] = []

    valid_extensions: Set[str] = {".jpg", ".jpeg", ".png", ".webp"}

    logger.info(f"Loading NSFW images from {nsfw_dir}")
    nsfw_path = Path(nsfw_dir)
    if nsfw_path.exists():
        for f in nsfw_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                nsfw_paths.append(str(f))
    else:
        logger.warning(f"NSFW directory not found: {nsfw_dir}")

    logger.info(f"Loading SFW human images from {sfw_human_dir}")
    sfw_humans_path = Path(sfw_human_dir)
    if sfw_humans_path.exists():
        for f in sfw_humans_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                sfw_paths.append(str(f))
    else:
        logger.warning(f"SFW humans directory not found: {sfw_human_dir}")

    logger.info(f"Loading SFW anime images from {sfw_anime_dir}")
    sfw_anime_path = Path(sfw_anime_dir)
    if sfw_anime_path.exists():
        for f in sfw_anime_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                sfw_paths.append(str(f))
    else:
        logger.warning(f"SFW anime directory not found: {sfw_anime_dir}")

    nsfw_count = len(nsfw_paths)
    sfw_count = len(sfw_paths)

    logger.info("Loading overrides (NSFW false positives)")
    override_path = Path("./dataset_clean/overrides")
    if override_path.exists():
        for f in override_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                override_paths.append(str(f))
    else:
        logger.warning("No override path found")

    logger.info(f"Dataset loaded before balancing: {nsfw_count} NSFW, {sfw_count} SFW")

    if nsfw_count > 0 and sfw_count > 0:
        current_ratio = nsfw_count / sfw_count
        logger.info(f"Current NSFW/SFW ratio: {current_ratio:.2f}")

        needs_balancing = current_ratio < 0.5 or current_ratio > 2.0

        target_ratio = 1.2

        if needs_balancing:
            if current_ratio > target_ratio:
                target_nsfw = int(sfw_count * target_ratio)
                target_sfw = sfw_count
                logger.info(f"Downsampling NSFW: {nsfw_count} -> {target_nsfw}")

                random.seed(42)
                nsfw_paths = random.sample(nsfw_paths, target_nsfw)
            else:
                target_nsfw = nsfw_count
                target_sfw = int(nsfw_count / target_ratio)
                logger.info(f"Downsampling SFW: {sfw_count} -> {target_sfw}")

                random.seed(42)
                sfw_paths = random.sample(sfw_paths, target_sfw)

            nsfw_count = len(nsfw_paths)
            sfw_count = len(sfw_paths)
            new_ratio = nsfw_count / sfw_count

            logger.info(f"Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW")
            logger.info(f"New NSFW/SFW ratio: {new_ratio:.2f}")

    image_paths: List[str] = []
    labels: List[int] = []

    for path in nsfw_paths:
        image_paths.append(path)
        labels.append(1)

    for path in sfw_paths:
        image_paths.append(path)
        labels.append(0)

    # Add in any overrides so they show up even after balancing
    for path in override_paths:
        image_paths.append(path)
        labels.append(0)

    combined = list(zip(image_paths, labels))
    random.shuffle(combined)
    image_paths, labels = zip(*combined)
    image_paths = list(image_paths)
    labels = list(labels)

    logger.info(f"Final dataset: {len(image_paths)} total images")

    return image_paths, labels


def create_dataloaders(
    image_paths: List[str],
    labels: List[int],
    processor: ViTImageProcessor,
    config: Config,
    augment: bool,
) -> Tuple[
    DataLoader[torch.Tensor], DataLoader[torch.Tensor], DataLoader[torch.Tensor]
]:
    total_size = len(image_paths)
    train_size = int(config.train_ratio * total_size)
    val_size = int(config.val_ratio * total_size)
    test_size = total_size - train_size - val_size

    # Split on an unaugmented dataset first, then re-wrap each split below so
    # only the training split gets augmentation.
    full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)

    train_dataset_temp, val_dataset_temp, test_dataset_temp = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    train_indices = train_dataset_temp.indices
    val_indices = val_dataset_temp.indices
    test_indices = test_dataset_temp.indices

    train_paths = [image_paths[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    train_dataset = NSFWDataset(train_paths, train_labels, processor, augment)

    val_paths = [image_paths[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]
    val_dataset = NSFWDataset(val_paths, val_labels, processor, augment=False)

    test_paths = [image_paths[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]
    test_dataset = NSFWDataset(test_paths, test_labels, processor, augment=False)

    pin_memory = config.device.type == "cuda"

    train_loader = DataLoader[torch.Tensor](
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader[torch.Tensor](
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    test_loader = DataLoader[torch.Tensor](
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader


def create_model(config: Config) -> ViTForImageClassification:
    logger.info(f"Loading pretrained ViT model: {config.model_name}")

    model = ViTForImageClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_classes,
        ignore_mismatched_sizes=True,
    )

    model = model.to(config.device)

    return model


def train_epoch(
    model: ViTForImageClassification,
    train_loader: DataLoader[torch.Tensor],
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    model.train()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0

    progress_bar = tqdm(train_loader, desc="Training")

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images: torch.Tensor = images.to(device)
        labels: torch.Tensor = labels.to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=images)
        logits: torch.Tensor = outputs.logits

        loss: torch.Tensor = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += int((predicted == labels).sum().item())

        progress_bar.set_postfix(
            {"loss": running_loss / (batch_idx + 1), "acc": 100.0 * correct / total}
        )

    epoch_loss: float = running_loss / len(train_loader)
    epoch_acc: float = 100.0 * correct / total

    return epoch_loss, epoch_acc


def validate(
    model: ViTForImageClassification,
    val_loader: DataLoader[torch.Tensor],
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float, List[int], List[int]]:
    model.eval()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0

    all_preds: List[int] = []
    all_labels: List[int] = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images: torch.Tensor = images.to(device)
            labels: torch.Tensor = labels.to(device)

            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits

            loss: torch.Tensor = criterion(logits, labels)

            running_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum().item())

            all_preds.extend(predicted.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

            # Divide by batches seen so far, not the full loader length,
            # so the running loss display is correct mid-epoch.
            progress_bar.set_postfix(
                {
                    "loss": running_loss / (batch_idx + 1),
                    "acc": 100.0 * correct / total,
                }
            )

    epoch_loss: float = running_loss / len(val_loader)
    epoch_acc: float = 100.0 * correct / total

    return epoch_loss, epoch_acc, all_preds, all_labels


def train_model(
    model: ViTForImageClassification,
    train_loader: DataLoader[torch.Tensor],
    val_loader: DataLoader[torch.Tensor],
    config: Config,
) -> Tuple[ViTForImageClassification, Dict[str, List[float]]]:
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=2,
    )

    best_val_acc = 0.0
    best_model_state: Optional[Dict[str, Any]] = None

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    for epoch in range(config.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{config.num_epochs}")
        print("-" * 60)

        train_loss, train_acc = train_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            config.device,
        )

        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)

        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot tensors to CPU clones: state_dict() values are views of
            # the live parameters, so a shallow .copy() would keep tracking
            # subsequent updates.
            best_model_state = {
                k: v.detach().cpu().clone() for k, v in model.state_dict().items()
            }
            print(f"New best model! (Val Acc: {val_acc:.2f}%)")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    logger.info(f"Training complete. Best validation accuracy: {best_val_acc:.2f}%")

    return model, history


def save_model(
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    config: Config,
) -> None:
    Path(config.model_save_path).mkdir(parents=True, exist_ok=True)

    model.save_pretrained(config.model_save_path)
    processor.save_pretrained(config.model_save_path)

    print(f"Model saved to {config.model_save_path}")


def load_trained_model(
    model_path: str, device: torch.device
) -> Tuple[ViTForImageClassification, ViTImageProcessor]:
    model = ViTForImageClassification.from_pretrained(model_path)
    processor = ViTImageProcessor.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model, processor


def predict_image(
    image_path: str,
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    device: torch.device,
    config: Config,
) -> Tuple[str, float]:
    model.eval()

    image: Image.Image = Image.open(image_path).convert("RGB")
    inputs: Dict[str, torch.Tensor] = processor(images=image, return_tensors="pt")
    pixel_values: torch.Tensor = inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        logits: torch.Tensor = outputs.logits
        probs: torch.Tensor = torch.softmax(logits, dim=1)
        predicted_class: int = torch.argmax(probs, dim=1).item()
        confidence: float = probs[0, predicted_class].item()

    prediction: str = config.class_names[predicted_class]

    return prediction, confidence


def evaluate_model(
    model: ViTForImageClassification,
    test_loader: DataLoader[torch.Tensor],
    config: Config,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    model.eval()
    all_preds: List[int] = []
    all_labels: List[int] = []
    all_probs: List[List[float]] = []

    print("\n" + "=" * 60)
    print("Evaluating on Test Set")
    print("=" * 60)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images: torch.Tensor = images.to(config.device)

            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits

            probs: torch.Tensor = torch.softmax(logits, dim=1)
            _, predicted = torch.max(logits, 1)

            all_preds.extend(predicted.cpu().numpy().tolist())
            all_labels.extend(labels.numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())

    all_preds_array: np.ndarray = np.array(all_preds)
    all_labels_array: np.ndarray = np.array(all_labels)
    all_probs_array: np.ndarray = np.array(all_probs)

    print("\nClassification Report:")
    print(
        classification_report(
            all_labels_array, all_preds_array, target_names=config.class_names, digits=4
        )
    )

    cm: np.ndarray = confusion_matrix(all_labels_array, all_preds_array)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=config.class_names,
        yticklabels=config.class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
    plt.close()  # release the figure so repeated calls don't leak memory
    print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

    return all_preds_array, all_labels_array, all_probs_array


def plot_training_history(history: Dict[str, List[float]]) -> None:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history["train_loss"], label="Train Loss", marker="o")
    ax1.plot(history["val_loss"], label="Val Loss", marker="s")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training and Validation Loss")
    ax1.legend()
    ax1.grid(True)

    ax2.plot(history["train_acc"], label="Train Acc", marker="o")
    ax2.plot(history["val_acc"], label="Val Acc", marker="s")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy (%)")
    ax2.set_title("Training and Validation Accuracy")
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig("training_history.png", dpi=300, bbox_inches="tight")
    plt.close(fig)
    print("✓ Training history saved as 'training_history.png'")


class NSFWClassifier:
    def __init__(self, model_path: str, device: Optional[torch.device] = None) -> None:
        if device is None:
            self.device: torch.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = device

        print(f"Loading model from {model_path}...")
        self.model: ViTForImageClassification = (
            ViTForImageClassification.from_pretrained(model_path)
        )
        self.processor: ViTImageProcessor = ViTImageProcessor.from_pretrained(
            model_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

        self.class_names: List[str] = ["SFW", "NSFW"]
        print(f"✓ Model loaded on {self.device}")

    def predict_single(self, image_path: str) -> Dict[str, Union[str, float]]:
        image: Image.Image = Image.open(image_path).convert("RGB")

        inputs: Dict[str, torch.Tensor] = self.processor(
            images=image, return_tensors="pt"
        )
        pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device)

        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            logits: torch.Tensor = outputs.logits
            probs: torch.Tensor = torch.softmax(logits, dim=1)[0]
            predicted_class: int = torch.argmax(probs).item()

        return {
            "prediction": self.class_names[predicted_class],
            "confidence": probs[predicted_class].item(),
            "sfw_prob": probs[0].item(),
"nsfw_prob": probs[1].item(), 600 } 601 602 def predict_single_bytes(self, image_bytes: bytes) -> Dict[str, Union[str, float]]: 603 image = Image.open(BytesIO(image_bytes)).convert("RGB") 604 inputs: Dict[str, torch.Tensor] = self.processor( 605 images=image, return_tensors="pt" 606 ) 607 pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device) 608 609 with torch.no_grad(): 610 outputs = self.model(pixel_values=pixel_values) 611 logits: torch.Tensor = outputs.logits 612 probs: torch.Tensor = torch.softmax(logits, dim=1)[0] 613 predicted_class: int = torch.argmax(probs).item() 614 615 return { 616 "prediction": self.class_names[predicted_class], 617 "confidence": probs[predicted_class].item(), 618 "sfw_prob": probs[0].item(), 619 "nsfw_prob": probs[1].item(), 620 } 621 622 623@click.group() 624def cli(): 625 pass 626 627 628@cli.command() 629@click.option( 630 "--augment", 631 type=bool, 632 default=False, 633 required=False, 634) 635def train( 636 augment: Optional[bool], 637): 638 config = CONFIG 639 640 print("=" * 60) 641 print("NSFW Image Classifier - ViT Finetuner") 642 print("=" * 60) 643 print(f"Device: {config.device}") 644 print(f"Model: {config.model_name}") 645 print(f"Batch size: {config.batch_size}") 646 print(f"Learning rate: {config.learning_rate}") 647 print(f"Epochs: {config.num_epochs}") 648 649 logger.info("Loading image processor") 650 processor = ViTImageProcessor.from_pretrained(config.model_name) 651 652 logger.info("Creating dataset") 653 image_paths, labels = load_image_paths_and_labels( 654 config.nsfw_dir, 655 config.sfw_humans_dir, 656 config.sfw_anime_dir, 657 ) 658 659 logger.info("Creating dataloaders") 660 train_loader, val_loader, test_loader = create_dataloaders( 661 image_paths, labels, processor, config, augment=augment or False 662 ) 663 664 model = create_model(config) 665 666 model, history = train_model(model, train_loader, val_loader, config) 667 668 logger.info("Evaluating model") 669 evaluate_model(model, test_loader, config) 670 671 logger.info("Generating plots") 672 plot_training_history(history) 673 674 logger.info("Saving model") 675 save_model(model, processor, config) 676 677 678@cli.command() 679@click.argument("image", type=str, required=True) 680def predict( 681 image: str, 682): 683 classifer = NSFWClassifier( 684 CONFIG.model_save_path, 685 CONFIG.device, 686 ) 687 688 print(f"{classifer.predict_single(image)}") 689 690 691@cli.command() 692@click.argument("url", type=str, required=True) 693def predict_url( 694 url: str, 695): 696 classifier = NSFWClassifier( 697 CONFIG.model_save_path, 698 CONFIG.device, 699 ) 700 701 resp = requests.get(url) 702 resp.raise_for_status() 703 704 # we can use predict_bytes_single, but we wont so that we can do something with this image like 705 # move it to a directory for retraining 706 with open("./testimg.jpeg", "wb") as f: 707 f.write(resp.content) 708 709 print(f"{classifier.predict_single('./testimg.jpeg')}") 710 711 712if __name__ == "__main__": 713 cli() 714