"""Gradio demo: pick which of two images a PickScore model prefers for a caption."""

import base64
import io
from typing import List, Tuple

import gradio as gr
import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from transformers import AutoModel, AutoProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only used on GPU; CPU inference stays in float32.
dtype = torch.float16 if device == "cuda" else torch.float32

# Load example dataset
dataset = load_dataset("xzuyn/dalle-3_vs_sd-v1-5_dpo", num_proc=4)

processor_name = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
model_name = "yuvalkirstain/PickScore_v1"
processor = AutoProcessor.from_pretrained(processor_name)
model = AutoModel.from_pretrained(model_name, torch_dtype=dtype).to(device)


def decode_image(image: str) -> Image.Image:
    """
    Decodes base64 string to PIL image.

    Args:
        image: base64 string

    Returns:
        PIL image
    """
    raw_bytes = base64.b64decode(image)
    return Image.open(io.BytesIO(raw_bytes))


def get_logits(caption: str, imgs: List[Image.Image]) -> torch.Tensor:
    """
    Returns the logits for the caption and images.

    Args:
        caption: string
        imgs: list of images to score against the caption

    Returns:
        logits: the model's ``logits_per_image`` tensor
    """
    inputs = processor(
        text=caption,
        images=imgs,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77,
    ).to(device)
    # Match the pixel tensor to the dtype the model was loaded with
    # (float16 on CUDA, float32 on CPU) so the forward pass does not
    # mix precisions.
    inputs["pixel_values"] = inputs["pixel_values"].to(dtype)

    with torch.no_grad():
        outputs = model(**inputs)

    return outputs.logits_per_image


def get_preference(
    img_1: Image.Image, img_2: Image.Image, caption: str
) -> Image.Image:
    """
    Returns the preference of the caption for the two images.

    Args:
        img_1: first image
        img_2: second image
        caption: string

    Returns:
        preference image: whichever of ``img_1`` / ``img_2`` scores higher

    Raises:
        gr.Error: if either image is missing (e.g. the button was clicked
            before both images were uploaded).
    """
    # NOTE(review): Gradio's Image component presumably hands images over as
    # numpy arrays unless ``type="pil"`` is set; the processor is given
    # whatever arrives here unchanged -- confirm against the Gradio docs.
    if img_1 is None or img_2 is None:
        raise gr.Error("Please provide both images before requesting a preference.")

    imgs = [img_1, img_2]
    logits = get_logits(caption, imgs)
    # A single caption is scored against two images, so the flat argmax
    # index is the index of the preferred image.
    preference = logits.argmax().item()
    return imgs[preference]


def sample_example() -> Tuple[Image.Image, Image.Image, Image.Image, str]:
    """
    Samples a random example from the dataset and displays it.

    Returns:
        img_1: PIL image
        img_2: PIL image
        preference: PIL image
        caption: string
    """
    train_split = dataset["train"]
    example = train_split[np.random.randint(0, len(train_split))]

    img_1 = decode_image(example["jpg_0"])
    img_2 = decode_image(example["jpg_1"])
    caption = example["caption"]

    # Reuse the same scoring path as the "Get Preference" button.
    preferred = get_preference(img_1, img_2, caption)
    return (img_1, img_2, preferred, caption)


### Description
title = r"""

Aesthetic Scorer: CLIP fine-tuned for DPO scoring

"""

description = r""" This is a demo for the paper Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation
How to use this demo:
1. Upload two images generated using the same caption. 2. Enter the caption used to generate the images. 3. Click on the "Get Preference" button to get the image which scores higher on user preferences according to the model.
OR
1. Click on the "Random Example" button to get a random example from a Dalle 3 vs SD 1.5 DPO dataset.
This demo demonstrates the use of this CLIP variant for DPO scoring. The scores can then be used for DPO fine-tuning with these scripts.
Accuracy on the Dalle 3 vs SD 1.5 DPO dataset:
PickScore_v1 - 97.3
CLIPSeg - 70.9
CLIP-ViT-H-14-laion2B-s32B-b79K - 82.3
"""

# The bibtex fence needs its own lines to render as a code block.
citation = r"""
📝 **Citation**
```bibtex
@inproceedings{Kirstain2023PickaPicAO,
  title={Pick-a-Pic: An Open Dataset of User Preferences for Text-to-Image Generation},
  author={Yuval Kirstain and Adam Polyak and Uriel Singer and Shahbuland Matiana and Joe Penna and Omer Levy},
  year={2023}
}
```
"""

with gr.Blocks() as demo:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        first_image = gr.Image(height=400, width=400, label="First Image")
        second_image = gr.Image(height=400, width=400, label="Second Image")

    caption_box = gr.Textbox(lines=1, label="Caption")

    with gr.Row():
        image_button = gr.Button("Get Preference")
        random_example = gr.Button("Random Example")

    image_output = gr.Image(height=400, width=400, label="Preference")

    image_button.click(
        get_preference,
        inputs=[first_image, second_image, caption_box],
        outputs=image_output,
    )
    random_example.click(
        sample_example,
        outputs=[first_image, second_image, image_output, caption_box],
    )

    gr.Markdown(citation)

if __name__ == "__main__":
    demo.launch()