import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoProcessor, FocalNetForImageClassification
import gradio as gr
# Path to the model
MODEL_PATH = "MichalMlodawski/nsfw-image-detection-large"
# Load the model and feature extractor
feature_extractor = AutoProcessor.from_pretrained(MODEL_PATH)
model = FocalNetForImageClassification.from_pretrained(MODEL_PATH)
model.eval()
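# Optional sketch (an assumption, not part of the original app): run inference
# on a GPU when one is available. If enabled, the tensors returned by the
# feature extractor in classify_image() must be moved to the same device.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)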
# Image transformations (defined for reference; not used by classify_image()
# below, which relies on the AutoProcessor for preprocessing)
transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
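# If this manual pipeline were applied instead of the AutoProcessor (a sketch,
# not what classify_image() below actually does), the call would be:
#     pixel_values = transform(pil_image).unsqueeze(0)  # add a batch dimension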
# Mapping from model labels to NSFW categories
LABEL_TO_CATEGORY = {
    "LABEL_0": "Safe",
    "LABEL_1": "Questionable",
    "LABEL_2": "Unsafe",
}
def classify_image(image):
    if image is None:
        return "No image uploaded"
    # Convert to RGB (in case of PNG with alpha channel)
    image = Image.fromarray(image).convert("RGB")
    # Process the image using the feature extractor
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Run the model without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        confidence, predicted = torch.max(probabilities, 1)
    # Get the label from the model's configuration
    label = model.config.id2label[predicted.item()]
    category = LABEL_TO_CATEGORY.get(label, "Unknown")
    confidence_value = confidence.item() * 100
    # Prepare the result string with an emoji and a 10-segment confidence bar
    emoji = {"Safe": "✅", "Questionable": "⚠️", "Unsafe": "🔞"}.get(category, "❓")
    confidence_bar = "🟩" * int(confidence_value // 10) + "⬜" * (10 - int(confidence_value // 10))
    result = f"{emoji} NSFW Category: {category}\n"
    result += f"🏷️ Model Label: {label}\n"
    result += f"🎯 Confidence: {confidence_value:.2f}% {confidence_bar}"
    return result
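# A quick sanity check outside the UI (a sketch; "example.jpg" is a
# hypothetical path, not part of this Space):
#
#     import numpy as np
#     sample = np.array(Image.open("example.jpg"))
#     print(classify_image(sample))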
# Define the Gradio interface
iface = gr.Interface(
    fn=classify_image,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Textbox(label="Classification Result"),
    title="🖼️ NSFW Image Classification 🔍",
    description="Upload an image to classify its safety level!",
    theme=gr.themes.Soft(primary_hue="purple"),
    allow_flagging="never",
)
if __name__ == "__main__":
    iface.launch()
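    # When running outside Hugging Face Spaces, launch(share=True) is a Gradio
    # option that creates a temporary public link (shown here as a commented
    # alternative, not this app's configured behavior):
    # iface.launch(share=True)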