import gradio as gr
import torch
import clip

# Run on the GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP model and its matching image preprocessing transform once,
# at import time, so every request reuses them.
model, preprocess = clip.load("ViT-B/32", device=device)


def predict(image, labels: str) -> dict:
    """Score an image against a comma-separated list of candidate classes.

    Each class name is wrapped in the prompt ``"a photo of a {name}"``,
    tokenized, and passed to the CLIP model together with the preprocessed
    image. The per-image logits are turned into probabilities with a softmax
    over the classes.

    Args:
        image: The image to classify; it is passed straight to the
            ``preprocess`` transform returned by ``clip.load`` (the UI below
            requests PIL images).
        labels: Candidate class names separated by commas, for example
            ``"cat, dog, ball"``. Whitespace around each name is ignored and
            empty entries are dropped.

    Returns:
        A dict mapping each class name to its probability as a ``float``.

    Raises:
        ValueError: If ``image`` is ``None`` or ``labels`` contains no
            non-empty class name.
    """
    if image is None:
        raise ValueError("Please provide an image to classify.")

    # Strip whitespace so "cat, dog" yields "dog" rather than " dog" (which
    # would otherwise leak into both the prompt and the returned keys), and
    # drop empty entries produced by trailing or doubled commas.
    class_names = [name.strip() for name in (labels or "").split(",")]
    class_names = [name for name in class_names if name]
    if not class_names:
        raise ValueError("Please provide at least one class name.")

    # unsqueeze(0) adds the leading batch dimension: a batch of one image.
    image_batch = preprocess(image).unsqueeze(0).to(device)
    prompts = [f"a photo of a {name}" for name in class_names]
    text = clip.tokenize(prompts).to(device)

    with torch.inference_mode():
        # The model also returns per-text logits, which are not needed here.
        logits_per_image, _ = model(image_batch, text)
        probs = logits_per_image.softmax(dim=-1).cpu().numpy()

    # probs has one row per image; the single row holds one score per class.
    return {name: float(score) for name, score in zip(class_names, probs[0])}


# Manual smoke test (needs an `Image` import that this file does not have):
# probs = predict(Image.open("../CLIP/CLIP.png"), "cat, dog, ball")
# print(probs)

gr.Interface(
    fn=predict,
    inputs=[
        gr.inputs.Image(label="Image to classify.", type="pil"),
        gr.inputs.Textbox(
            lines=1,
            label="Comma separated classes",
            placeholder="Enter your classes separated by ','",
        ),
    ],
    theme="grass",
    outputs="label",
    description="Zero Shot Image classification..",
).launch()