|
import gradio as gr |
|
import torch |
|
import clip |
|
|
|
# Run CLIP on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the ViT-B/32 CLIP model together with the image transform that
# `predict` applies to incoming images before encoding them.
model, preprocess = clip.load("ViT-B/32", device=device)
|
|
|
|
|
def predict(image, labels):
    """Zero-shot classify ``image`` against user-supplied class names with CLIP.

    Each class name is turned into the prompt ``"a photo of a <name>"``.
    The image is scored against every prompt, and the scores are normalised
    with a softmax over the classes.

    Args:
        image: Image from the Gradio ``Image(type="pil")`` input; passed to
            the CLIP ``preprocess`` transform.
        labels: Comma-separated class names, e.g. ``"cat, dog, bird"``.

    Returns:
        Dict mapping each distinct, non-blank class name (in input order) to
        its probability as a Python float.

    Raises:
        ValueError: If no image was provided or ``labels`` contains no
            non-blank class name.
    """
    if image is None:
        raise ValueError("Please provide an image to classify.")

    # Strip surrounding whitespace, drop empty entries (e.g. from a trailing
    # comma) and de-duplicate while keeping order, so "cat, dog, cat," yields
    # ["cat", "dog"] instead of ["cat", " dog", "cat", ""].
    class_names = list(
        dict.fromkeys(name.strip() for name in labels.split(",") if name.strip())
    )
    if not class_names:
        raise ValueError("Please enter at least one class name.")

    # Add a batch dimension of 1 before moving the image to the model's device.
    image_tensor = preprocess(image).unsqueeze(0).to(device)
    text_tokens = clip.tokenize([f"a photo of a {name}" for name in class_names]).to(device)

    with torch.inference_mode():
        # Only the image->text logits are needed; the text->image view is discarded.
        logits_per_image, _ = model(image_tensor, text_tokens)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # probs has one row (the single image); pair each class with its score.
    return {name: float(p) for name, p in zip(class_names, probs[0])}
|
|
|
|
|
|
|
|
|
|
|
# Web UI: an image plus a comma-separated list of class names in, a label
# widget showing the per-class probabilities returned by `predict` out.
# Uses the top-level component classes (gr.Image / gr.Textbox); the old
# gr.inputs.* namespace is deprecated and was removed in Gradio 4.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Image to classify.", type="pil"),
        gr.Textbox(
            lines=1,
            label="Comma separated classes",
            placeholder="Enter your classes separated by ','",
        ),
    ],
    # NOTE(review): "grass" is a legacy Gradio theme name; newer Gradio
    # versions may warn and fall back to the default theme — confirm intent.
    theme="grass",
    outputs="label",
    description="Zero Shot Image classification.",
)

demo.launch()
|
|