|
import os |
|
import csv |
|
import json |
|
import torch |
|
import argparse |
|
import pandas as pd |
|
from tqdm import tqdm |
|
from peft import LoraConfig, get_peft_model |
|
from torch.utils.data import Dataset, DataLoader |
|
from transformers.models.llama.tokenization_llama import LlamaTokenizer |
|
from mplug_owl_video.modeling_mplug_owl import MplugOwlForConditionalGeneration |
|
from mplug_owl_video.processing_mplug_owl import MplugOwlImageProcessor, MplugOwlProcessor |
|
|
|
|
|
# Prompt template for the misalignment-feedback query. The `{caption}` field is
# filled in with `str.format(caption=...)` before the text is handed to the
# processor. `<|video|>` is kept verbatim in the prompt; presumably it is the
# placeholder the processor swaps for the video features (confirm against
# MplugOwlProcessor).
PROMPT_FEEDBACK = '''The following is a conversation between a curious human and AI assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
Human: <|video|>
Human: What is the misalignment between this video and the description: "{caption}"?
AI: '''
|
|
|
# Keyword arguments unpacked into `model.generate(...)`: sampling is enabled,
# restricted to the 5 highest-probability tokens per step, with `max_length`
# set to 512.
# NOTE(review): in Hugging Face `generate`, `max_length` is documented as
# counting the prompt tokens as well as the generated ones -- confirm that 512
# leaves enough room for the answer once the video prompt is included.
generate_kwargs = {
    'do_sample': True,
    'top_k': 5,
    'max_length': 512,
}
|
|
|
class VideoCaptionDataset(Dataset):
    """Single-item dataset pairing one video path with one caption.

    Args:
        videopath: Path of the video; stored and returned unchanged under the
            ``'videopath'`` key.
        text: Caption to compare against the video; stored and returned
            unchanged under the ``'neg_caption'`` key.
    """

    def __init__(self, videopath, text):
        self.videopath = videopath
        self.text = text

    def __len__(self):
        # The dataset always holds exactly one (video, caption) pair.
        return 1

    def __getitem__(self, index):
        """Return the only item as ``{'videopath': ..., 'neg_caption': ...}``.

        Raises:
            IndexError: If ``index`` is neither 0 nor -1. Raising here keeps
                the class consistent with ``__len__`` and lets plain iteration
                (``for item in dataset`` / ``list(dataset)``) terminate, since
                that protocol stops only when ``__getitem__`` raises
                ``IndexError``.
        """
        if index not in (0, -1):
            raise IndexError(
                f'VideoCaptionDataset holds a single item; got index {index}'
            )
        return {
            'videopath': self.videopath,
            'neg_caption': self.text,
        }
|
|
|
def get_nle(model, processor, tokenizer, dataloader):
    """Generate the model's explanation of how a caption misaligns with a video.

    For each batch the caption is inserted into ``PROMPT_FEEDBACK``, the prompt
    and video paths are run through ``processor`` (requesting 32 frames and
    PyTorch tensors), and ``model.generate`` is called with ``generate_kwargs``
    under ``torch.no_grad()``.

    Args:
        model: Model exposing ``generate(...)`` and a ``device`` attribute.
        processor: Callable accepting ``text``, ``videos``, ``num_frames`` and
            ``return_tensors`` and returning a mapping of tensors.
        tokenizer: Object whose ``decode`` turns generated token ids into text.
        dataloader: Iterable of batches with ``'videopath'`` and
            ``'neg_caption'`` entries. Only the first caption of each batch
            is used, and only the first generated sequence is decoded.

    Returns:
        The decoded text for the batch processed last (the only batch when the
        loader wraps a single-item dataset), or ``None`` if the loader yields
        no batches.
    """
    # Bound up front so an empty loader returns None instead of hitting an
    # unbound local at the return statement.
    generated_nle = None
    with torch.no_grad():
        # Wrapping the loader itself (rather than enumerate(...)) lets tqdm
        # read its length; the batch index was never used.
        for batch in tqdm(dataloader):
            videopaths = batch['videopath']
            neg_caption = batch['neg_caption'][0]
            prompts = [PROMPT_FEEDBACK.format(caption=neg_caption)]
            inputs = processor(text=prompts, videos=videopaths, num_frames=32, return_tensors='pt')
            # float32 tensors are cast to bfloat16 (other dtypes are left
            # alone), then everything is moved to the model's device.
            inputs = {
                name: (value.bfloat16() if value.dtype == torch.float else value).to(model.device)
                for name, value in inputs.items()
            }
            res = model.generate(**inputs, **generate_kwargs)
            generated_nle = tokenizer.decode(res.tolist()[0], skip_special_tokens=True)
    return generated_nle