this repo has no description

change model

Changed files
+38 -191
+3 -3
config.py
··· 11 11 12 12 model_save_path: str = "./models/nsfw_vit_classifier" 13 13 14 - model_name: str = "google/vit-base-patch16-224" 14 + model_name: str = "Falconsai/nsfw_image_detection" 15 15 num_classes: int = 2 16 16 17 - batch_size: int = 32 17 + batch_size: int = 16 18 18 num_epochs: int = 10 19 - learning_rate: float = 2e-5 19 + learning_rate: float = 5e-6 20 20 weight_decay: float = 0.01 21 21 22 22 train_ratio: float = 0.7
+35 -188
main.py
··· 1 1 import logging 2 - import shutil 3 2 import random 3 + from io import BytesIO 4 4 from pathlib import Path 5 5 from typing import Any, Dict, List, Optional, Set, Tuple, Union 6 6 ··· 43 43 self.transform = transforms.Compose( 44 44 [ 45 45 transforms.RandomHorizontalFlip(p=0.5), 46 - transforms.RandomRotation(20), 46 + transforms.RandomRotation(15), 47 47 transforms.ColorJitter( 48 - brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1 48 + brightness=0.2, contrast=0.2, saturation=0.2 49 49 ), 50 - transforms.RandomResizedCrop(224, scale=(0.7, 1.0)), 51 - transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), 52 - transforms.RandomGrayscale(p=0.1), 50 + transforms.RandomResizedCrop(224, scale=(0.8, 1.0)), 53 51 ] 54 52 ) 55 53 else: ··· 97 95 nsfw_path = Path(nsfw_dir) 98 96 if nsfw_path.exists(): 99 97 for f in nsfw_path.iterdir(): 100 - if f.suffix.lower() in valid_extensions or f.name.startswith("did:"): 98 + if f.suffix.lower() in valid_extensions: 101 99 nsfw_paths.append(str(f)) 102 100 else: 103 101 logger.warning(f"NSFW directory not found: {nsfw_dir}") ··· 106 104 sfw_humans_path = Path(sfw_human_dir) 107 105 if sfw_humans_path.exists(): 108 106 for f in sfw_humans_path.iterdir(): 109 - if f.suffix.lower() in valid_extensions or f.name.startswith("did:"): 107 + if f.suffix.lower() in valid_extensions: 110 108 sfw_paths.append(str(f)) 111 109 else: 112 110 logger.warning(f"SFW humans directory not found: {sfw_human_dir}") ··· 115 113 sfw_anime_path = Path(sfw_anime_dir) 116 114 if sfw_anime_path.exists(): 117 115 for f in sfw_anime_path.iterdir(): 118 - if f.suffix.lower() in valid_extensions or f.name.startswith("did:"): 116 + if f.suffix.lower() in valid_extensions: 119 117 sfw_paths.append(str(f)) 120 118 else: 121 119 logger.warning(f"SFW anime directory not found: {sfw_anime_dir}") ··· 143 141 target_ratio = 1.2 144 142 145 143 if needs_balancing: 146 - logger.info(f"🔄 Auto-balancing to target ratio: {target_ratio:.2f}") 147 - 148 144 if current_ratio > target_ratio: 149 145 target_nsfw = int(sfw_count * target_ratio) 150 146 target_sfw = sfw_count ··· 164 160 sfw_count = len(sfw_paths) 165 161 new_ratio = nsfw_count / sfw_count 166 162 167 - logger.info( 168 - f"✅ Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW" 169 - ) 163 + logger.info(f"Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW") 170 164 logger.info(f"New NSFW/SFW ratio: {new_ratio:.2f}") 171 165 172 - else: 173 - logger.info("✅ Dataset is already balanced, no adjustment needed") 174 - 175 - elif nsfw_count == 0: 176 - logger.error("❌ No NSFW images found! Check your data directory.") 177 - elif sfw_count == 0: 178 - logger.error("❌ No SFW images found! Check your data directory.") 179 - 180 - # Combine and create labels 181 166 image_paths: List[str] = [] 182 167 labels: List[int] = [] 183 168 184 - # Add NSFW (label = 1) 185 169 for path in nsfw_paths: 186 170 image_paths.append(path) 187 171 labels.append(1) 188 172 189 - # Add SFW (label = 0) 190 173 for path in sfw_paths: 191 174 image_paths.append(path) 192 175 labels.append(0) ··· 196 179 image_paths.append(path) 197 180 labels.append(0) 198 181 199 - # Shuffle to mix NSFW and SFW 200 182 combined = list(zip(image_paths, labels)) 201 183 random.shuffle(combined) 202 184 image_paths, labels = zip(*combined) 203 185 image_paths = list(image_paths) 204 186 labels = list(labels) 205 187 206 - logger.info(f"📊 Final dataset: {len(image_paths)} total images") 188 + logger.info(f"Final dataset: {len(image_paths)} total images") 207 189 208 190 return image_paths, labels 209 191 ··· 283 265 ignore_mismatched_sizes=True, 284 266 ) 285 267 286 - for param in model.vit.embeddings.parameters(): 287 - param.requires_grad = False 288 - for param in model.vit.encoder.layer[:6].parameters(): 289 - param.requires_grad = False 290 - 291 268 model = model.to(config.device) 292 269 293 270 return model ··· 314 291 optimizer.zero_grad() 315 292 316 293 outputs = model(pixel_values=images) 317 - logits: torch.Tensor = outputs.logits # Shape: (batch_size, num_classes) 294 + logits: torch.Tensor = outputs.logits 318 295 319 296 loss: torch.Tensor = criterion(logits, labels) 320 297 ··· 386 363 val_loader: DataLoader[torch.Tensor], 387 364 config: Config, 388 365 ): 389 - criterion = nn.CrossEntropyLoss(label_smoothing=0.1) 366 + criterion = nn.CrossEntropyLoss() 390 367 optimizer = optim.AdamW( 391 368 model.parameters(), 392 369 lr=config.learning_rate, ··· 431 408 history["val_acc"].append(val_acc) 432 409 433 410 print(f"\nEpoch {epoch + 1} Summary:") 434 - print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}") 435 - print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}") 411 + print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}") 412 + print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}") 436 413 437 414 if val_acc > best_val_acc: 438 415 best_val_acc = val_acc ··· 622 599 "nsfw_prob": probs[1].item(), 623 600 } 624 601 625 - def predict_batch( 626 - self, folder_path: str, extensions: Optional[List[str]] = None 627 - ) -> List[Dict[str, Union[str, float]]]: 628 - if extensions is None: 629 - extensions = [".jpg", ".jpeg", ".png", ".bmp", ".webp"] 630 - 631 - folder: Path = Path(folder_path) 632 - results: List[Dict[str, Union[str, float]]] = [] 633 - 634 - # Get all image files 635 - image_files: List[Path] = [] 636 - for ext in extensions: 637 - image_files.extend(folder.glob(f"*{ext}")) 638 - image_files.extend(folder.glob(f"*{ext.upper()}")) 639 - 640 - print(f"Found {len(image_files)} images") 641 - 642 - for img_path in image_files: 643 - try: 644 - result: Dict[str, Union[str, float]] = self.predict_single( 645 - str(img_path) 646 - ) 647 - result["filename"] = img_path.name 648 - results.append(result) 649 - 650 - print( 651 - f" {img_path.name}: {result['prediction']} " 652 - f"({result['confidence']:.2%})" 653 - ) 654 - except Exception as e: 655 - print(f" Error processing {img_path.name}: {e}") 656 - 657 - return results 658 - 659 - def predict_with_threshold( 660 - self, image_path: str, nsfw_threshold: float = 0.5 661 - ) -> Dict[str, Union[str, float]]: 662 - result: Dict[str, Union[str, float]] = self.predict_single(image_path) 602 + def predict_single_bytes(self, image_bytes: bytes) -> Dict[str, Union[str, float]]: 603 + image = Image.open(BytesIO(image_bytes)).convert("RGB") 604 + inputs: Dict[str, torch.Tensor] = self.processor( 605 + images=image, return_tensors="pt" 606 + ) 607 + pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device) 663 608 664 - if result["nsfw_prob"] >= nsfw_threshold: 665 - result["prediction"] = "NSFW" 666 - else: 667 - result["prediction"] = "SFW" 609 + with torch.no_grad(): 610 + outputs = self.model(pixel_values=pixel_values) 611 + logits: torch.Tensor = outputs.logits 612 + probs: torch.Tensor = torch.softmax(logits, dim=1)[0] 613 + predicted_class: int = torch.argmax(probs).item() 668 614 669 - return result 670 - 671 - 672 - def save_misclassified( 673 - model: ViTForImageClassification, 674 - test_loader: DataLoader[torch.Tensor], 675 - config: Config, 676 - output_dir: str = "misclassified", 677 - ) -> None: 678 - model.eval() 679 - 680 - fp_dir = Path(output_dir) / "false_positives" 681 - fn_dir = Path(output_dir) / "false_negatives" 682 - fp_dir.mkdir(parents=True, exist_ok=True) 683 - fn_dir.mkdir(parents=True, exist_ok=True) 684 - 685 - dataset = test_loader.dataset 686 - 687 - false_positives = [] 688 - false_negatives = [] 689 - 690 - with torch.no_grad(): 691 - sample_idx = 0 692 - for images, labels in tqdm(test_loader, desc="Finding misclassified"): 693 - images = images.to(config.device) 694 - 695 - outputs = model(pixel_values=images) 696 - logits = outputs.logits 697 - probs = torch.softmax(logits, dim=1) 698 - _, predicted = torch.max(logits, 1) 699 - 700 - for i in range(len(labels)): 701 - true_label = labels[i].item() 702 - pred_label = predicted[i].item() 703 - confidence = probs[i, pred_label].item() 704 - 705 - if true_label != pred_label: 706 - img_path = dataset.image_paths[sample_idx] 707 - 708 - if true_label == 0 and pred_label == 1: 709 - false_positives.append( 710 - { 711 - "path": img_path, 712 - "confidence": confidence, 713 - "nsfw_prob": probs[i, 1].item(), 714 - } 715 - ) 716 - elif true_label == 1 and pred_label == 0: 717 - false_negatives.append( 718 - { 719 - "path": img_path, 720 - "confidence": confidence, 721 - "sfw_prob": probs[i, 0].item(), 722 - } 723 - ) 724 - 725 - sample_idx += 1 726 - 727 - import shutil 728 - 729 - for i, fp in enumerate(false_positives): 730 - src = Path(fp["path"]) 731 - dst = fp_dir / f"{fp['confidence']:.3f}_{src.name}" 732 - shutil.copy(src, dst) 733 - 734 - for i, fn in enumerate(false_negatives): 735 - src = Path(fn["path"]) 736 - dst = fn_dir / f"{fn['confidence']:.3f}_{src.name}" 737 - shutil.copy(src, dst) 738 - 739 - print(f"\n{'=' * 60}") 740 - print("Misclassification Analysis") 741 - print(f"{'=' * 60}") 742 - print(f"False Positives (SFW → NSFW): {len(false_positives)}") 743 - print(f"False Negatives (NSFW → SFW): {len(false_negatives)}") 744 - print(f"\nImages saved to: {output_dir}/") 745 - print(f"- false_positives/ ({len(false_positives)} images)") 746 - print(f"- false_negatives/ ({len(false_negatives)} images)") 747 - 748 - report_path = Path(output_dir) / "report.txt" 749 - with open(report_path, "w") as f: 750 - f.write("FALSE POSITIVES (SFW classified as NSFW)\n") 751 - f.write("=" * 50 + "\n") 752 - for fp in sorted(false_positives, key=lambda x: x["confidence"], reverse=True): 753 - f.write(f"{fp['confidence']:.3f} NSFW prob | {fp['path']}\n") 754 - 755 - f.write("\n\nFALSE NEGATIVES (NSFW classified as SFW)\n") 756 - f.write("=" * 50 + "\n") 757 - for fn in sorted(false_negatives, key=lambda x: x["confidence"], reverse=True): 758 - f.write(f"{fn['confidence']:.3f} SFW prob | {fn['path']}\n") 615 + return { 616 + "prediction": self.class_names[predicted_class], 617 + "confidence": probs[predicted_class].item(), 618 + "sfw_prob": probs[0].item(), 619 + "nsfw_prob": probs[1].item(), 620 + } 759 621 760 622 761 623 @click.group() ··· 803 665 804 666 model, history = train_model(model, train_loader, val_loader, config) 805 667 806 - print("\nStep 6: Evaluating model...") 807 668 logger.info("Evaluating model") 808 669 evaluate_model(model, test_loader, config) 809 670 810 - logger.info("Saving misclassified images") 811 - save_misclassified(model, test_loader, config, output_dir="misclassified") 812 - 813 671 logger.info("Generating plots") 814 672 plot_training_history(history) 815 673 816 674 logger.info("Saving model") 817 675 save_model(model, processor, config) 818 676 819 - print("\n" + "=" * 60) 820 - print("✓ Training pipeline complete!") 821 - print("=" * 60) 822 - 823 - print("\nExample: Testing inference on a sample image...") 824 - if len(image_paths) > 0: 825 - sample_image: str = image_paths[0] 826 - prediction, confidence = predict_image( 827 - sample_image, model, processor, config.device, config 828 - ) 829 - print(f"Sample image: {sample_image}") 830 - print(f"Prediction: {prediction} (confidence: {confidence:.2%})") 831 - 832 677 833 678 @cli.command() 834 679 @click.argument("image", type=str, required=True) ··· 853 698 CONFIG.device, 854 699 ) 855 700 856 - resp = requests.get(url, stream=True) 701 + resp = requests.get(url) 857 702 resp.raise_for_status() 858 703 704 + # we can use predict_bytes_single, but we wont so that we can do something with this image like 705 + # move it to a directory for retraining 859 706 with open("./testimg.jpeg", "wb") as f: 860 - resp.raw.decode_content = True 861 - shutil.copyfileobj(resp.raw, f) 707 + f.write(resp.content) 862 708 863 709 print(f"{classifier.predict_single('./testimg.jpeg')}") 864 710 865 711 866 712 if __name__ == "__main__": 867 713 cli() 714 +