+3
-3
config.py
+3
-3
config.py
···
11
11
12
12
model_save_path: str = "./models/nsfw_vit_classifier"
13
13
14
-
model_name: str = "google/vit-base-patch16-224"
14
+
model_name: str = "Falconsai/nsfw_image_detection"
15
15
num_classes: int = 2
16
16
17
-
batch_size: int = 32
17
+
batch_size: int = 16
18
18
num_epochs: int = 10
19
-
learning_rate: float = 2e-5
19
+
learning_rate: float = 5e-6
20
20
weight_decay: float = 0.01
21
21
22
22
train_ratio: float = 0.7
+35
-188
main.py
+35
-188
main.py
···
1
1
import logging
2
-
import shutil
3
2
import random
3
+
from io import BytesIO
4
4
from pathlib import Path
5
5
from typing import Any, Dict, List, Optional, Set, Tuple, Union
6
6
···
43
43
self.transform = transforms.Compose(
44
44
[
45
45
transforms.RandomHorizontalFlip(p=0.5),
46
-
transforms.RandomRotation(20),
46
+
transforms.RandomRotation(15),
47
47
transforms.ColorJitter(
48
-
brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1
48
+
brightness=0.2, contrast=0.2, saturation=0.2
49
49
),
50
-
transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
51
-
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
52
-
transforms.RandomGrayscale(p=0.1),
50
+
transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
53
51
]
54
52
)
55
53
else:
···
97
95
nsfw_path = Path(nsfw_dir)
98
96
if nsfw_path.exists():
99
97
for f in nsfw_path.iterdir():
100
-
if f.suffix.lower() in valid_extensions or f.name.startswith("did:"):
98
+
if f.suffix.lower() in valid_extensions:
101
99
nsfw_paths.append(str(f))
102
100
else:
103
101
logger.warning(f"NSFW directory not found: {nsfw_dir}")
···
106
104
sfw_humans_path = Path(sfw_human_dir)
107
105
if sfw_humans_path.exists():
108
106
for f in sfw_humans_path.iterdir():
109
-
if f.suffix.lower() in valid_extensions or f.name.startswith("did:"):
107
+
if f.suffix.lower() in valid_extensions:
110
108
sfw_paths.append(str(f))
111
109
else:
112
110
logger.warning(f"SFW humans directory not found: {sfw_human_dir}")
···
115
113
sfw_anime_path = Path(sfw_anime_dir)
116
114
if sfw_anime_path.exists():
117
115
for f in sfw_anime_path.iterdir():
118
-
if f.suffix.lower() in valid_extensions or f.name.startswith("did:"):
116
+
if f.suffix.lower() in valid_extensions:
119
117
sfw_paths.append(str(f))
120
118
else:
121
119
logger.warning(f"SFW anime directory not found: {sfw_anime_dir}")
···
143
141
target_ratio = 1.2
144
142
145
143
if needs_balancing:
146
-
logger.info(f"🔄 Auto-balancing to target ratio: {target_ratio:.2f}")
147
-
148
144
if current_ratio > target_ratio:
149
145
target_nsfw = int(sfw_count * target_ratio)
150
146
target_sfw = sfw_count
···
164
160
sfw_count = len(sfw_paths)
165
161
new_ratio = nsfw_count / sfw_count
166
162
167
-
logger.info(
168
-
f"✅ Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW"
169
-
)
163
+
logger.info(f"Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW")
170
164
logger.info(f"New NSFW/SFW ratio: {new_ratio:.2f}")
171
165
172
-
else:
173
-
logger.info("✅ Dataset is already balanced, no adjustment needed")
174
-
175
-
elif nsfw_count == 0:
176
-
logger.error("❌ No NSFW images found! Check your data directory.")
177
-
elif sfw_count == 0:
178
-
logger.error("❌ No SFW images found! Check your data directory.")
179
-
180
-
# Combine and create labels
181
166
image_paths: List[str] = []
182
167
labels: List[int] = []
183
168
184
-
# Add NSFW (label = 1)
185
169
for path in nsfw_paths:
186
170
image_paths.append(path)
187
171
labels.append(1)
188
172
189
-
# Add SFW (label = 0)
190
173
for path in sfw_paths:
191
174
image_paths.append(path)
192
175
labels.append(0)
···
196
179
image_paths.append(path)
197
180
labels.append(0)
198
181
199
-
# Shuffle to mix NSFW and SFW
200
182
combined = list(zip(image_paths, labels))
201
183
random.shuffle(combined)
202
184
image_paths, labels = zip(*combined)
203
185
image_paths = list(image_paths)
204
186
labels = list(labels)
205
187
206
-
logger.info(f"📊 Final dataset: {len(image_paths)} total images")
188
+
logger.info(f"Final dataset: {len(image_paths)} total images")
207
189
208
190
return image_paths, labels
209
191
···
283
265
ignore_mismatched_sizes=True,
284
266
)
285
267
286
-
for param in model.vit.embeddings.parameters():
287
-
param.requires_grad = False
288
-
for param in model.vit.encoder.layer[:6].parameters():
289
-
param.requires_grad = False
290
-
291
268
model = model.to(config.device)
292
269
293
270
return model
···
314
291
optimizer.zero_grad()
315
292
316
293
outputs = model(pixel_values=images)
317
-
logits: torch.Tensor = outputs.logits # Shape: (batch_size, num_classes)
294
+
logits: torch.Tensor = outputs.logits
318
295
319
296
loss: torch.Tensor = criterion(logits, labels)
320
297
···
386
363
val_loader: DataLoader[torch.Tensor],
387
364
config: Config,
388
365
):
389
-
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
366
+
criterion = nn.CrossEntropyLoss()
390
367
optimizer = optim.AdamW(
391
368
model.parameters(),
392
369
lr=config.learning_rate,
···
431
408
history["val_acc"].append(val_acc)
432
409
433
410
print(f"\nEpoch {epoch + 1} Summary:")
434
-
print(f" Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}")
435
-
print(f" Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}")
411
+
print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}")
412
+
print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}")
436
413
437
414
if val_acc > best_val_acc:
438
415
best_val_acc = val_acc
···
622
599
"nsfw_prob": probs[1].item(),
623
600
}
624
601
625
-
def predict_batch(
626
-
self, folder_path: str, extensions: Optional[List[str]] = None
627
-
) -> List[Dict[str, Union[str, float]]]:
628
-
if extensions is None:
629
-
extensions = [".jpg", ".jpeg", ".png", ".bmp", ".webp"]
630
-
631
-
folder: Path = Path(folder_path)
632
-
results: List[Dict[str, Union[str, float]]] = []
633
-
634
-
# Get all image files
635
-
image_files: List[Path] = []
636
-
for ext in extensions:
637
-
image_files.extend(folder.glob(f"*{ext}"))
638
-
image_files.extend(folder.glob(f"*{ext.upper()}"))
639
-
640
-
print(f"Found {len(image_files)} images")
641
-
642
-
for img_path in image_files:
643
-
try:
644
-
result: Dict[str, Union[str, float]] = self.predict_single(
645
-
str(img_path)
646
-
)
647
-
result["filename"] = img_path.name
648
-
results.append(result)
649
-
650
-
print(
651
-
f" {img_path.name}: {result['prediction']} "
652
-
f"({result['confidence']:.2%})"
653
-
)
654
-
except Exception as e:
655
-
print(f" Error processing {img_path.name}: {e}")
656
-
657
-
return results
658
-
659
-
def predict_with_threshold(
660
-
self, image_path: str, nsfw_threshold: float = 0.5
661
-
) -> Dict[str, Union[str, float]]:
662
-
result: Dict[str, Union[str, float]] = self.predict_single(image_path)
602
+
def predict_single_bytes(self, image_bytes: bytes) -> Dict[str, Union[str, float]]:
603
+
image = Image.open(BytesIO(image_bytes)).convert("RGB")
604
+
inputs: Dict[str, torch.Tensor] = self.processor(
605
+
images=image, return_tensors="pt"
606
+
)
607
+
pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device)
663
608
664
-
if result["nsfw_prob"] >= nsfw_threshold:
665
-
result["prediction"] = "NSFW"
666
-
else:
667
-
result["prediction"] = "SFW"
609
+
with torch.no_grad():
610
+
outputs = self.model(pixel_values=pixel_values)
611
+
logits: torch.Tensor = outputs.logits
612
+
probs: torch.Tensor = torch.softmax(logits, dim=1)[0]
613
+
predicted_class: int = torch.argmax(probs).item()
668
614
669
-
return result
670
-
671
-
672
-
def save_misclassified(
673
-
model: ViTForImageClassification,
674
-
test_loader: DataLoader[torch.Tensor],
675
-
config: Config,
676
-
output_dir: str = "misclassified",
677
-
) -> None:
678
-
model.eval()
679
-
680
-
fp_dir = Path(output_dir) / "false_positives"
681
-
fn_dir = Path(output_dir) / "false_negatives"
682
-
fp_dir.mkdir(parents=True, exist_ok=True)
683
-
fn_dir.mkdir(parents=True, exist_ok=True)
684
-
685
-
dataset = test_loader.dataset
686
-
687
-
false_positives = []
688
-
false_negatives = []
689
-
690
-
with torch.no_grad():
691
-
sample_idx = 0
692
-
for images, labels in tqdm(test_loader, desc="Finding misclassified"):
693
-
images = images.to(config.device)
694
-
695
-
outputs = model(pixel_values=images)
696
-
logits = outputs.logits
697
-
probs = torch.softmax(logits, dim=1)
698
-
_, predicted = torch.max(logits, 1)
699
-
700
-
for i in range(len(labels)):
701
-
true_label = labels[i].item()
702
-
pred_label = predicted[i].item()
703
-
confidence = probs[i, pred_label].item()
704
-
705
-
if true_label != pred_label:
706
-
img_path = dataset.image_paths[sample_idx]
707
-
708
-
if true_label == 0 and pred_label == 1:
709
-
false_positives.append(
710
-
{
711
-
"path": img_path,
712
-
"confidence": confidence,
713
-
"nsfw_prob": probs[i, 1].item(),
714
-
}
715
-
)
716
-
elif true_label == 1 and pred_label == 0:
717
-
false_negatives.append(
718
-
{
719
-
"path": img_path,
720
-
"confidence": confidence,
721
-
"sfw_prob": probs[i, 0].item(),
722
-
}
723
-
)
724
-
725
-
sample_idx += 1
726
-
727
-
import shutil
728
-
729
-
for i, fp in enumerate(false_positives):
730
-
src = Path(fp["path"])
731
-
dst = fp_dir / f"{fp['confidence']:.3f}_{src.name}"
732
-
shutil.copy(src, dst)
733
-
734
-
for i, fn in enumerate(false_negatives):
735
-
src = Path(fn["path"])
736
-
dst = fn_dir / f"{fn['confidence']:.3f}_{src.name}"
737
-
shutil.copy(src, dst)
738
-
739
-
print(f"\n{'=' * 60}")
740
-
print("Misclassification Analysis")
741
-
print(f"{'=' * 60}")
742
-
print(f"False Positives (SFW → NSFW): {len(false_positives)}")
743
-
print(f"False Negatives (NSFW → SFW): {len(false_negatives)}")
744
-
print(f"\nImages saved to: {output_dir}/")
745
-
print(f"- false_positives/ ({len(false_positives)} images)")
746
-
print(f"- false_negatives/ ({len(false_negatives)} images)")
747
-
748
-
report_path = Path(output_dir) / "report.txt"
749
-
with open(report_path, "w") as f:
750
-
f.write("FALSE POSITIVES (SFW classified as NSFW)\n")
751
-
f.write("=" * 50 + "\n")
752
-
for fp in sorted(false_positives, key=lambda x: x["confidence"], reverse=True):
753
-
f.write(f"{fp['confidence']:.3f} NSFW prob | {fp['path']}\n")
754
-
755
-
f.write("\n\nFALSE NEGATIVES (NSFW classified as SFW)\n")
756
-
f.write("=" * 50 + "\n")
757
-
for fn in sorted(false_negatives, key=lambda x: x["confidence"], reverse=True):
758
-
f.write(f"{fn['confidence']:.3f} SFW prob | {fn['path']}\n")
615
+
return {
616
+
"prediction": self.class_names[predicted_class],
617
+
"confidence": probs[predicted_class].item(),
618
+
"sfw_prob": probs[0].item(),
619
+
"nsfw_prob": probs[1].item(),
620
+
}
759
621
760
622
761
623
@click.group()
···
803
665
804
666
model, history = train_model(model, train_loader, val_loader, config)
805
667
806
-
print("\nStep 6: Evaluating model...")
807
668
logger.info("Evaluating model")
808
669
evaluate_model(model, test_loader, config)
809
670
810
-
logger.info("Saving misclassified images")
811
-
save_misclassified(model, test_loader, config, output_dir="misclassified")
812
-
813
671
logger.info("Generating plots")
814
672
plot_training_history(history)
815
673
816
674
logger.info("Saving model")
817
675
save_model(model, processor, config)
818
676
819
-
print("\n" + "=" * 60)
820
-
print("✓ Training pipeline complete!")
821
-
print("=" * 60)
822
-
823
-
print("\nExample: Testing inference on a sample image...")
824
-
if len(image_paths) > 0:
825
-
sample_image: str = image_paths[0]
826
-
prediction, confidence = predict_image(
827
-
sample_image, model, processor, config.device, config
828
-
)
829
-
print(f"Sample image: {sample_image}")
830
-
print(f"Prediction: {prediction} (confidence: {confidence:.2%})")
831
-
832
677
833
678
@cli.command()
834
679
@click.argument("image", type=str, required=True)
···
853
698
CONFIG.device,
854
699
)
855
700
856
-
resp = requests.get(url, stream=True)
701
+
resp = requests.get(url)
857
702
resp.raise_for_status()
858
703
704
+
# We could use predict_single_bytes here, but we won't, so that we can do something with this image, like
705
+
# move it to a directory for retraining
859
706
with open("./testimg.jpeg", "wb") as f:
860
-
resp.raw.decode_content = True
861
-
shutil.copyfileobj(resp.raw, f)
707
+
f.write(resp.content)
862
708
863
709
print(f"{classifier.predict_single('./testimg.jpeg')}")
864
710
865
711
866
712
if __name__ == "__main__":
867
713
cli()
714
+