---
library_name: transformers
license: apache-2.0
language:
  - en
base_model:
  - mistralai/Pixtral-12B-2409
pipeline_tag: image-to-text
---

# Pixtral-12B-Captioner-Relaxed

## Introduction

Pixtral-12B-Captioner-Relaxed is an instruction-tuned version of Pixtral-12B-2409, an advanced multimodal large language model. It was fine-tuned on a hand-curated dataset for text-to-image models and produces significantly more detailed descriptions of the images it is given.

**Key Features:**

- **Enhanced Detail:** Generates more comprehensive and nuanced image descriptions.
- **Relaxed Constraints:** Offers less restrictive image descriptions compared to the base model.
- **Natural Language Output:** Describes different subjects in the image while specifying their locations using natural language.
- **Optimized for Image Generation:** Produces captions in formats compatible with state-of-the-art text-to-image generation models.

**Note:** This fine-tuned model is optimized for creating text-to-image datasets. As a result, performance on other complex tasks may be lower than that of the original model.

## Requirements

The 12B model needs 24 GB of VRAM at half precision. It can also be loaded with 8-bit or 4-bit quantization, but expect degraded performance.
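
If 24 GB is more than your GPU has, 8-bit loading sits between half precision and the 4-bit setup shown in the Quickstart below. A minimal sketch using the bitsandbytes integration in transformers (the exact memory savings and quality trade-off depend on your GPU and settings):

```python
import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig

# 8-bit quantization via bitsandbytes: roughly halves weight VRAM use
# compared to bfloat16, with some loss in caption quality.
bnb_config = BitsAndBytesConfig(load_in_8bit=True)

model = LlavaForConditionalGeneration.from_pretrained(
    "Ertugrul/Pixtral-12B-Captioner-Relaxed",
    device_map="auto",
    quantization_config=bnb_config,
)
```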

## Quickstart

```python
from PIL import Image
from transformers import LlavaForConditionalGeneration, AutoProcessor
from transformers import BitsAndBytesConfig
import torch

# Example quantization config: add it to the model-loading call below
# to use 4-bit quantization.
quantization_config = BitsAndBytesConfig(
    # load_in_8bit=True,
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
)

model_id = "Ertugrul/Pixtral-12B-Captioner-Relaxed"
model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16)
processor = AutoProcessor.from_pretrained(model_id)

# To load the model with quantization instead, use:
# model = LlavaForConditionalGeneration.from_pretrained(model_id, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image.\n"},
            {"type": "image"},
        ],
    }
]

PROMPT = processor.apply_chat_template(conversation, add_generation_prompt=True)

image = Image.open(r"PATH_TO_YOUR_IMAGE")

def resize_image(image, target_size=768):
    """Resize the image to have the target size on the shortest side."""
    width, height = image.size
    if width < height:
        new_width = target_size
        new_height = int(height * (new_width / width))
    else:
        new_height = target_size
        new_width = int(width * (new_height / height))
    return image.resize((new_width, new_height), Image.LANCZOS)


# you can try different resolutions or disable it completely
image = resize_image(image, 768)


inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")


with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        generate_ids = model.generate(**inputs, max_new_tokens=384, do_sample=True, temperature=0.3, use_cache=True, top_k=20)
output_text = processor.batch_decode(generate_ids[:, inputs.input_ids.shape[1]:], skip_special_tokens=True, clean_up_tokenization_spaces=True)[0]

print(output_text)
```
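
Since the model is aimed at building text-to-image datasets, a natural next step is captioning a whole folder of images. The sketch below reuses `model`, `processor`, `PROMPT`, and `resize_image` from the Quickstart; the `caption_image` helper and the one-`.txt`-per-image layout are illustrative assumptions, not part of the original card:

```python
from pathlib import Path

def caption_image(image, max_new_tokens=384):
    # Same preprocessing and generation settings as the Quickstart above.
    image = resize_image(image, 768)
    inputs = processor(text=PROMPT, images=image, return_tensors="pt").to("cuda")
    with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        generate_ids = model.generate(
            **inputs, max_new_tokens=max_new_tokens,
            do_sample=True, temperature=0.3, top_k=20, use_cache=True,
        )
    return processor.batch_decode(
        generate_ids[:, inputs.input_ids.shape[1]:],
        skip_special_tokens=True, clean_up_tokenization_spaces=True,
    )[0]

for path in Path("images").glob("*.jpg"):
    caption = caption_image(Image.open(path).convert("RGB"))
    # Save the caption next to the image, a layout many text-to-image
    # trainers expect (one .txt file per image).
    path.with_suffix(".txt").write_text(caption)
```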

## Acknowledgements

For more detailed options, refer to the Pixtral-12B-2409 or mistral-community/pixtral-12b documentation.

You can also try Qwen2-VL-7B-Captioner-Relaxed, a smaller alternative model trained in a similar manner.