this repo has no description

save misclassifications

Changed files
+92
+92
main.py
··· 669 669 return result 670 670 671 671 672 + def save_misclassified( 673 + model: ViTForImageClassification, 674 + test_loader: DataLoader[torch.Tensor], 675 + config: Config, 676 + output_dir: str = "misclassified", 677 + ) -> None: 678 + model.eval() 679 + 680 + fp_dir = Path(output_dir) / "false_positives" 681 + fn_dir = Path(output_dir) / "false_negatives" 682 + fp_dir.mkdir(parents=True, exist_ok=True) 683 + fn_dir.mkdir(parents=True, exist_ok=True) 684 + 685 + dataset = test_loader.dataset 686 + 687 + false_positives = [] 688 + false_negatives = [] 689 + 690 + with torch.no_grad(): 691 + sample_idx = 0 692 + for images, labels in tqdm(test_loader, desc="Finding misclassified"): 693 + images = images.to(config.device) 694 + 695 + outputs = model(pixel_values=images) 696 + logits = outputs.logits 697 + probs = torch.softmax(logits, dim=1) 698 + _, predicted = torch.max(logits, 1) 699 + 700 + for i in range(len(labels)): 701 + true_label = labels[i].item() 702 + pred_label = predicted[i].item() 703 + confidence = probs[i, pred_label].item() 704 + 705 + if true_label != pred_label: 706 + img_path = dataset.image_paths[sample_idx] 707 + 708 + if true_label == 0 and pred_label == 1: 709 + false_positives.append( 710 + { 711 + "path": img_path, 712 + "confidence": confidence, 713 + "nsfw_prob": probs[i, 1].item(), 714 + } 715 + ) 716 + elif true_label == 1 and pred_label == 0: 717 + false_negatives.append( 718 + { 719 + "path": img_path, 720 + "confidence": confidence, 721 + "sfw_prob": probs[i, 0].item(), 722 + } 723 + ) 724 + 725 + sample_idx += 1 726 + 727 + import shutil 728 + 729 + for i, fp in enumerate(false_positives): 730 + src = Path(fp["path"]) 731 + dst = fp_dir / f"{fp['confidence']:.3f}_{src.name}" 732 + shutil.copy(src, dst) 733 + 734 + for i, fn in enumerate(false_negatives): 735 + src = Path(fn["path"]) 736 + dst = fn_dir / f"{fn['confidence']:.3f}_{src.name}" 737 + shutil.copy(src, dst) 738 + 739 + print(f"\n{'=' * 60}") 740 + print("Misclassification Analysis") 741 + print(f"{'=' * 60}") 742 + print(f"False Positives (SFW → NSFW): {len(false_positives)}") 743 + print(f"False Negatives (NSFW → SFW): {len(false_negatives)}") 744 + print(f"\nImages saved to: {output_dir}/") 745 + print(f"- false_positives/ ({len(false_positives)} images)") 746 + print(f"- false_negatives/ ({len(false_negatives)} images)") 747 + 748 + report_path = Path(output_dir) / "report.txt" 749 + with open(report_path, "w") as f: 750 + f.write("FALSE POSITIVES (SFW classified as NSFW)\n") 751 + f.write("=" * 50 + "\n") 752 + for fp in sorted(false_positives, key=lambda x: x["confidence"], reverse=True): 753 + f.write(f"{fp['confidence']:.3f} NSFW prob | {fp['path']}\n") 754 + 755 + f.write("\n\nFALSE NEGATIVES (NSFW classified as SFW)\n") 756 + f.write("=" * 50 + "\n") 757 + for fn in sorted(false_negatives, key=lambda x: x["confidence"], reverse=True): 758 + f.write(f"{fn['confidence']:.3f} SFW prob | {fn['path']}\n") 759 + 760 + 672 761 @click.group() 673 762 def cli(): 674 763 pass ··· 717 806 print("\nStep 6: Evaluating model...") 718 807 logger.info("Evaluating model") 719 808 evaluate_model(model, test_loader, config) 809 + 810 + logger.info("Saving misclassified images") 811 + save_misclassified(model, test_loader, config, output_dir="misclassified") 720 812 721 813 logger.info("Generating plots") 722 814 plot_training_history(history)