main.py (+92 lines)
···
     return result


+def save_misclassified(
+    model: ViTForImageClassification,
+    test_loader: DataLoader[torch.Tensor],
+    config: Config,
+    output_dir: str = "misclassified",
+) -> None:
+    model.eval()
+
+    fp_dir = Path(output_dir) / "false_positives"
+    fn_dir = Path(output_dir) / "false_negatives"
+    fp_dir.mkdir(parents=True, exist_ok=True)
+    fn_dir.mkdir(parents=True, exist_ok=True)
+
+    dataset = test_loader.dataset
+
+    false_positives = []
+    false_negatives = []
+
+    with torch.no_grad():
+        sample_idx = 0
+        for images, labels in tqdm(test_loader, desc="Finding misclassified"):
+            images = images.to(config.device)
+
+            outputs = model(pixel_values=images)
+            logits = outputs.logits
+            probs = torch.softmax(logits, dim=1)
+            _, predicted = torch.max(logits, 1)
+
+            for i in range(len(labels)):
+                true_label = labels[i].item()
+                pred_label = predicted[i].item()
+                confidence = probs[i, pred_label].item()
+
+                if true_label != pred_label:
+                    # Assumes the loader visits the dataset in order (no shuffling),
+                    # so sample_idx maps back onto dataset.image_paths.
+                    img_path = dataset.image_paths[sample_idx]
+
+                    if true_label == 0 and pred_label == 1:
+                        false_positives.append(
+                            {
+                                "path": img_path,
+                                "confidence": confidence,
+                                "nsfw_prob": probs[i, 1].item(),
+                            }
+                        )
+                    elif true_label == 1 and pred_label == 0:
+                        false_negatives.append(
+                            {
+                                "path": img_path,
+                                "confidence": confidence,
+                                "sfw_prob": probs[i, 0].item(),
+                            }
+                        )
+
+                sample_idx += 1
+
+    import shutil
+
+    # Copy each misclassified image, prefixing the file name with the model's confidence.
+    for fp in false_positives:
+        src = Path(fp["path"])
+        dst = fp_dir / f"{fp['confidence']:.3f}_{src.name}"
+        shutil.copy(src, dst)
+
+    for fn in false_negatives:
+        src = Path(fn["path"])
+        dst = fn_dir / f"{fn['confidence']:.3f}_{src.name}"
+        shutil.copy(src, dst)
+
+    print(f"\n{'=' * 60}")
+    print("Misclassification Analysis")
+    print(f"{'=' * 60}")
+    print(f"False Positives (SFW → NSFW): {len(false_positives)}")
+    print(f"False Negatives (NSFW → SFW): {len(false_negatives)}")
+    print(f"\nImages saved to: {output_dir}/")
+    print(f"- false_positives/ ({len(false_positives)} images)")
+    print(f"- false_negatives/ ({len(false_negatives)} images)")
+
+    report_path = Path(output_dir) / "report.txt"
+    with open(report_path, "w") as f:
+        f.write("FALSE POSITIVES (SFW classified as NSFW)\n")
+        f.write("=" * 50 + "\n")
+        for fp in sorted(false_positives, key=lambda x: x["confidence"], reverse=True):
+            f.write(f"{fp['confidence']:.3f} NSFW prob | {fp['path']}\n")
+
+        f.write("\n\nFALSE NEGATIVES (NSFW classified as SFW)\n")
+        f.write("=" * 50 + "\n")
+        for fn in sorted(false_negatives, key=lambda x: x["confidence"], reverse=True):
+            f.write(f"{fn['confidence']:.3f} SFW prob | {fn['path']}\n")
+
+
 @click.group()
 def cli():
     pass
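
Note on the added function: the running sample_idx is used to index test_loader.dataset.image_paths, which is only valid when the test loader walks the dataset in its stored order. A minimal sketch of that assumption follows; the make_eval_loader name and batch size are placeholders, not part of this PR.

from torch.utils.data import DataLoader, Dataset

def make_eval_loader(dataset: Dataset, batch_size: int = 32) -> DataLoader:
    # shuffle=False keeps batch order aligned with dataset.image_paths,
    # which is what the sample_idx bookkeeping in save_misclassified relies on.
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)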
···
     print("\nStep 6: Evaluating model...")
     logger.info("Evaluating model")
     evaluate_model(model, test_loader, config)
+
+    logger.info("Saving misclassified images")
+    save_misclassified(model, test_loader, config, output_dir="misclassified")

     logger.info("Generating plots")
     plot_training_history(history)
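
For spot-checking a run afterwards, a small helper like the sketch below could parse the report.txt written by save_misclassified. report_entries is a hypothetical utility, not part of this PR; it relies only on the "<confidence> NSFW prob | <path>" and "<confidence> SFW prob | <path>" line formats defined above.

from pathlib import Path

def report_entries(
    report_path: str = "misclassified/report.txt",
    marker: str = " NSFW prob | ",
) -> list[str]:
    # marker=" NSFW prob | " selects false positives, " SFW prob | " selects
    # false negatives; entries are already written in descending confidence order.
    lines = Path(report_path).read_text().splitlines()
    return [line for line in lines if marker in line]

if __name__ == "__main__":
    # Print the five most confidently wrong false positives from the last run.
    for line in report_entries()[:5]:
        print(line)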