import copy
import json
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Sequence

import numpy as np
import tokenizers
import torch
import transformers

from longvu import conversation as conversation_lib
from longvu.constants import (
    DEFAULT_IM_END_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    IMAGE_TOKEN_INDEX,
)

from decord import cpu, VideoReader
from packaging import version
from PIL import Image
from torch import distributed as dist
from torch.distributed.fsdp import (
    FullStateDictConfig,
    FullyShardedDataParallel as FSDP,
    StateDictType,
)
from torch.utils.data import Dataset
from transformers import StoppingCriteria

from longvu.mm_utils import KeywordsStoppingCriteria

IS_TOKENIZER_GREATER_THAN_0_14 = version.parse(tokenizers.__version__) >= version.parse(
    "0.14"
)


def maybe_zero_3(param, ignore_status: bool = False, name=None):
    # The DeepSpeed ZeRO-3 gathering logic is omitted in this build; parameters
    # are assumed to be fully materialized, so we simply detach them and copy
    # them to CPU.
    return param.detach().cpu().clone()


def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    to_return = {
        k: t
        for k, t in named_params
        if any(key_match in k for key_match in keys_to_match)
    }
    to_return = {
        k: maybe_zero_3(v, ignore_status=True, name=k).cpu()
        for k, v in to_return.items()
    }
    return to_return


def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split(".")
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if "lm_head" in lora_module_names:
        lora_module_names.remove("lm_head")
    return list(lora_module_names)
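
# Illustrative sketch (not used elsewhere in this file): the module names
# returned by `find_all_linear_names` are typically handed to peft's LoraConfig
# as LoRA target modules. `model`, `r`, and `lora_alpha` below are assumptions
# for the example, not values defined in this repository.
#
#   from peft import LoraConfig
#
#   lora_config = LoraConfig(
#       r=16,
#       lora_alpha=32,
#       target_modules=find_all_linear_names(model),
#       lora_dropout=0.05,
#       task_type="CAUSAL_LM",
#   )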


def safe_save_model_for_hf_trainer(
    trainer: transformers.Trainer, output_dir: str
) -> None:
    """Collects the state dict and dump to disk."""
    global_rank = dist.get_rank()
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)

    if len(trainer.args.fsdp) == 0:
        cpu_state_dict = trainer.model.state_dict()
    else:
        with FSDP.state_dict_type(
            trainer.model, StateDictType.FULL_STATE_DICT, save_policy
        ):
            cpu_state_dict = trainer.model.state_dict()

    for key in cpu_state_dict.keys():
        cpu_state_dict[key] = cpu_state_dict[key].to(torch.bfloat16)

    if global_rank == 0:
        trainer.model.config.save_pretrained(output_dir)
        current_folder = output_dir.split("/")[-1]
        parent_folder = os.path.dirname(output_dir)
        save_path = os.path.join(output_dir, "pytorch_model.bin")
        if getattr(trainer.args, "tune_mm_mlp_adapter", False) and not getattr(
            trainer.args, "tune_text_decoder", False
        ):
            keys_to_match = ["mm_projector"]
            if getattr(trainer.args, "use_im_start_end", False):
                keys_to_match.extend(["embed_tokens", "embed_in"])

            freeze_layer_remove = []
            for key in cpu_state_dict.keys():
                remove = True
                for key_match in keys_to_match:
                    if key_match in key:
                        remove = False
                        break
                if remove:
                    freeze_layer_remove.append(key)
            for key in freeze_layer_remove:
                del cpu_state_dict[key]

            if current_folder.startswith("checkpoint-"):
                mm_projector_folder = os.path.join(parent_folder, "mm_projector")
                os.makedirs(mm_projector_folder, exist_ok=True)
                save_path = os.path.join(mm_projector_folder, f"{current_folder}.bin")
            else:
                save_path = os.path.join(output_dir, "mm_projector.bin")
        torch.save(cpu_state_dict, save_path)


def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
) -> None:
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)

    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
            dim=0, keepdim=True
        )

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg
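
# Example usage (illustrative only; the pad token string is an assumption, not
# a value defined in this file): register a pad token and average-initialize
# the newly created embedding rows.
#
#   if tokenizer.pad_token is None:
#       smart_tokenizer_and_embedding_resize(
#           special_tokens_dict=dict(pad_token="[PAD]"),
#           tokenizer=tokenizer,
#           model=model,
#       )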


def _tokenize_fn(
    strings: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize a list of strings."""
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    input_ids_lens = labels_lens = [
        tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item()
        for tokenized in tokenized_list
    ]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def _mask_targets(target, tokenized_lens, speakers) -> None:
    cur_idx = tokenized_lens[0]
    tokenized_lens = tokenized_lens[1:]
    target[:cur_idx] = IGNORE_INDEX
    for tokenized_len, speaker in zip(tokenized_lens, speakers):
        if speaker == "human":
            target[cur_idx + 2 : cur_idx + tokenized_len] = IGNORE_INDEX
        cur_idx += tokenized_len


def _add_speaker_and_signal(header, source, get_conversation: bool = True):
    """Add speaker and start/end signal on each round."""
    BEGIN_SIGNAL = "### "
    END_SIGNAL = "\n"
    conversation = header
    for sentence in source:
        from_str = sentence["from"]
        if from_str.lower() == "human":
            from_str = conversation_lib.default_conversation.roles[0]
        elif from_str.lower() == "gpt":
            from_str = conversation_lib.default_conversation.roles[1]
        else:
            from_str = "unknown"
        sentence["value"] = (
            BEGIN_SIGNAL + from_str + ": " + sentence["value"] + END_SIGNAL
        )
        if get_conversation:
            conversation += sentence["value"]
    conversation += BEGIN_SIGNAL
    return conversation
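
# For reference, with the default role names the returned string looks roughly
# like the sketch below (role names depend on the active conversation template):
#
#   <system header>
#
#   ### Human: <question>
#   ### Assistant: <answer>
#   ###
#
# i.e. every round is wrapped in "### <role>: ..." and ends with a newline,
# and the string is terminated by a dangling "### " begin signal.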


def expand2square(pil_img, background_color):
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result
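
# Example (illustrative; `img` and `processor` are assumed variables): pad a
# 640x480 image to 640x640 with the processor's mean color, which is how
# `process_images` below calls this helper.
#
#   padded = expand2square(img, tuple(int(x * 255) for x in processor.image_mean))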


def process_images(images, image_processor, model_cfg):
    if isinstance(image_processor, list):
        processor_aux_list = image_processor
        new_images_aux_list = []
        for image in images:
            if isinstance(image, np.ndarray):
                image = Image.fromarray(image)
            image_aux_list = []
            for processor_aux in processor_aux_list:
                image_aux = image
                if hasattr(processor_aux, "image_mean"):
                    try:
                        target_resolution = processor_aux.crop_size["height"]
                    except Exception:
                        target_resolution = processor_aux.size["height"]
                    image_aux = expand2square(
                        image_aux, tuple(int(x * 255) for x in processor_aux.image_mean)
                    ).resize((target_resolution, target_resolution))
                image_aux = processor_aux.preprocess(image_aux, return_tensors="pt")[
                    "pixel_values"
                ][0]
                image_aux_list.append(image_aux)
            new_images_aux_list.append(image_aux_list)
        # Regroup from per-image lists into per-processor batches.
        new_images_aux_list = [
            list(batch_image_aux) for batch_image_aux in zip(*new_images_aux_list)
        ]
        new_images_aux_list = [
            torch.stack(image_aux).half().cuda() for image_aux in new_images_aux_list
        ]
        return new_images_aux_list
    else:
        image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
        new_images = []
        if image_aspect_ratio == "pad":
            for image in images:
                image = expand2square(
                    image, tuple(int(x * 255) for x in image_processor.image_mean)
                )
                image = image_processor.preprocess(image, return_tensors="pt")[
                    "pixel_values"
                ][0]
                new_images.append(image)
        else:
            return image_processor(images, return_tensors="pt")["pixel_values"]
        if all(x.shape == new_images[0].shape for x in new_images):
            new_images = torch.stack(new_images, dim=0)
        return new_images


def preprocess_multimodal(sources: Sequence[str], data_args) -> Dict:
    is_multimodal = data_args.is_multimodal
    if not is_multimodal:
        return sources

    for source in sources:
        for sentence in source:
            if (
                DEFAULT_IMAGE_TOKEN in sentence["value"]
                or "<video>" in sentence["value"]
            ):
                sentence["value"] = (
                    sentence["value"]
                    .replace(DEFAULT_IMAGE_TOKEN, "")
                    .replace("<video>", "")
                    .strip()
                )
                sentence["value"] = DEFAULT_IMAGE_TOKEN + "\n" + sentence["value"]
                sentence["value"] = sentence["value"].strip()
                if "mmtag" in conversation_lib.default_conversation.version:
                    sentence["value"] = sentence["value"].replace(
                        DEFAULT_IMAGE_TOKEN,
                        "<Image>" + DEFAULT_IMAGE_TOKEN + "</Image>",
                    )
            replace_token = DEFAULT_IMAGE_TOKEN
            if data_args.mm_use_im_start_end:
                replace_token = (
                    DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
                )
            sentence["value"] = sentence["value"].replace(
                DEFAULT_IMAGE_TOKEN, replace_token
            )

    return sources


def preprocess_llama_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_2

    sep = "[/INST] "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_v1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.TWO

    sep = conv.sep + conv.roles[1] + ": "
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep2)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 2
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 2

            if i != 0 and not tokenizer.legacy and IS_TOKENIZER_GREATER_THAN_0_14:
                round_len -= 1
                instruction_len -= 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def tokenizer_image_token(
    prompt,
    tokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    return_tensors=None,
):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if (
        len(prompt_chunks) > 0
        and len(prompt_chunks[0]) > 0
        and prompt_chunks[0][0] == tokenizer.bos_token_id
    ):
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
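
# Rough sketch of what the helper above produces (token ids shown are
# placeholders, not real vocabulary ids): the prompt is split on "<image>",
# each text chunk is tokenized separately, and IMAGE_TOKEN_INDEX is spliced in
# where the image placeholder used to be.
#
#   tokenizer_image_token("<image>\nWhat is shown?", tokenizer, return_tensors="pt")
#   # -> tensor([bos, IMAGE_TOKEN_INDEX, t1, t2, ...])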


def tokenizer_image_token_llama3(
    prompt,
    tokenizer,
    image_token_index=IMAGE_TOKEN_INDEX,
    return_tensors=None,
):
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split("<image>")]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    for x in insert_separator(prompt_chunks, [image_token_index]):
        input_ids.extend(x)

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids


def preprocess_qwen(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    system_message: str = "You are a helpful assistant.",
) -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a copy so the added "<image>" token does not leak into the shared tokenizer.
    tokenizer = copy.deepcopy(tokenizer)

    if has_image:
        tokenizer.add_tokens(["<image>"], special_tokens=True)

    image_token_index = tokenizer.convert_tokens_to_ids("<image>")
    im_start, im_end = tokenizer.additional_special_tokens_ids

    unmask_tokens_idx = [198, im_start, im_end]
    nl_tokens = tokenizer("\n").input_ids

    chat_template = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
    tokenizer.chat_template = chat_template

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        input_id += tokenizer.apply_chat_template(
            [{"role": "system", "content": system_message}]
        )
        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]
            encode_id = tokenizer.apply_chat_template(conv)
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == image_token_index:
                input_id[idx] = IMAGE_TOKEN_INDEX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )
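
# With the ChatML-style template set above, a single round renders roughly as
# (shown for illustration only):
#
#   <|im_start|>system
#   You are a helpful assistant.<|im_end|>
#   <|im_start|>user
#   <image>
#   What is in the video?<|im_end|>
#   <|im_start|>assistant
#   ...<|im_end|>
#
# Only assistant spans (plus the explicitly unmasked separator tokens) keep
# their labels; everything else is set to IGNORE_INDEX.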


def preprocess_llama3(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
    system_message: str = "You are a helpful assistant.",
) -> Dict:
    roles = {"human": "user", "gpt": "assistant"}

    # Work on a copy so the added "<image>" token does not leak into the shared tokenizer.
    tokenizer = copy.deepcopy(tokenizer)

    if has_image:
        tokenizer.add_tokens(["<image>"], special_tokens=True)
    image_token_index = tokenizer.convert_tokens_to_ids("<image>")
    bos_token_id = tokenizer.convert_tokens_to_ids("<|begin_of_text|>")
    start_header_id = tokenizer.convert_tokens_to_ids("<|start_header_id|>")
    end_header_id = tokenizer.convert_tokens_to_ids("<|end_header_id|>")
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")

    unmask_tokens = [
        "<|begin_of_text|>",
        "<|start_header_id|>",
        "<|end_header_id|>",
        "<|eot_id|>",
        "\n\n",
    ]
    unmask_tokens_idx = [tokenizer.convert_tokens_to_ids(tok) for tok in unmask_tokens]

    def safe_tokenizer_llama3(text):
        input_ids = tokenizer(text).input_ids
        if input_ids[0] == bos_token_id:
            input_ids = input_ids[1:]
        return input_ids

    nl_tokens = tokenizer.convert_tokens_to_ids("\n\n")

    chat_template = "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}"
    tokenizer.chat_template = chat_template

    input_ids, targets = [], []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != roles["human"]:
            source = source[1:]

        input_id, target = [], []

        # apply_chat_template always appends the assistant generation prompt
        # ("<|start_header_id|>assistant<|end_header_id|>\n\n", four tokens);
        # drop it here since the real turns follow.
        input_id += tokenizer.apply_chat_template(
            [{"role": "system", "content": system_message}]
        )[:-4]

        target += [IGNORE_INDEX] * len(input_id)

        for conv in source:
            try:
                role = conv["role"]
                content = conv["content"]
            except KeyError:
                role = conv["from"]
                content = conv["value"]

            role = roles.get(role, role)

            conv = [{"role": role, "content": content}]

            # Drop the leading bos token and the trailing generation prompt for
            # every per-turn encoding.
            encode_id = tokenizer.apply_chat_template(conv)[1:-4]
            input_id += encode_id
            if role in ["user", "system"]:
                target += [IGNORE_INDEX] * len(encode_id)
            else:
                target += encode_id

        assert len(input_id) == len(target), f"{len(input_id)} != {len(target)}"
        for idx, encode_id in enumerate(input_id):
            if encode_id in unmask_tokens_idx:
                target[idx] = encode_id
            if encode_id == image_token_index:
                input_id[idx] = IMAGE_TOKEN_INDEX
        input_ids.append(input_id)
        targets.append(target)
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    targets = torch.tensor(targets, dtype=torch.long)

    print("input_ids", input_ids, flush=True)
    print("targets", targets, flush=True)

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_llama_3_1(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            if sentence["from"] == "Answer":
                sentence["from"] = "gpt"
            role = roles[sentence["from"]]

            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
        input_ids = input_ids[:, 1:]
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_1

    sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"

    for conversation, target in zip(conversations, targets):
        total_len = int(target.shape[0])

        rounds = conversation.split(conv.tokenizer.eos_token)
        rounds = [rounds[0]] + [
            rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
        ]

        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2 and i != 0:
                break

            if i == 0:
                round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
                instruction_len = len(
                    tokenizer(rou, add_special_tokens=False).input_ids
                )
            else:
                parts[0] += sep
                if has_image:
                    round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
                    instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
                else:
                    round_len = len(tokenizer(rou).input_ids) + 1
                    instruction_len = len(tokenizer(parts[0]).input_ids)

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len

        target[cur_len:] = IGNORE_INDEX
        cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_llama_3_2(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    if input_ids[0][0] == input_ids[0][1] == tokenizer.bos_token_id:
        input_ids = input_ids[:, 1:]
    targets = input_ids.clone()

    assert conv.sep_style == conversation_lib.SeparatorStyle.LLAMA_3_2

    sep = "<|start_header_id|>" + conv.roles[1] + "<|end_header_id|>" + "\n\n"

    for conversation, target in zip(conversations, targets):
        total_len = int(target.shape[0])

        rounds = conversation.split(conv.tokenizer.eos_token)
        rounds = [rounds[0]] + [
            rounds[idx] + rounds[idx + 1] for idx in range(1, len(rounds) - 1, 2)
        ]

        cur_len = 1
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2 and i != 0:
                break

            if i == 0:
                round_len = len(tokenizer(rou, add_special_tokens=False).input_ids)
                instruction_len = len(
                    tokenizer(rou, add_special_tokens=False).input_ids
                )
            else:
                parts[0] += sep
                if has_image:
                    round_len = len(tokenizer_image_token(rou, tokenizer)) + 1
                    instruction_len = len(tokenizer_image_token(parts[0], tokenizer))
                else:
                    round_len = len(tokenizer(rou).input_ids) + 1
                    instruction_len = len(tokenizer(parts[0]).input_ids)

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
            cur_len += round_len

        target[cur_len:] = IGNORE_INDEX
        cur_len = cur_len + len(tokenizer(sep, add_special_tokens=False).input_ids)

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_phi3(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.conv_templates["phi3"].copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(
                conv.sep.join(rounds[conv_idx : conv_idx + 2])
            )
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep

            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            if i == 0:
                round_len += 1
                instruction_len += 1
            else:
                round_len -= 2
                instruction_len -= 2

            if (
                i != 0
                and getattr(tokenizer, "legacy", False)
                and IS_TOKENIZER_GREATER_THAN_0_14
            ):
                round_len += 1
                instruction_len += 1

            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_mpt(
    sources,
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    conv = conversation_lib.default_conversation.copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

    conversations = []
    for i, source in enumerate(sources):
        if roles[source[0]["from"]] != conv.roles[0]:
            source = source[1:]

        conv.messages = []
        for j, sentence in enumerate(source):
            role = roles[sentence["from"]]
            assert role == conv.roles[j % 2], f"{i}"
            conv.append_message(role, sentence["value"])
        conversations.append(conv.get_prompt())

    if has_image:
        input_ids = torch.stack(
            [
                tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
                for prompt in conversations
            ],
            dim=0,
        )
    else:
        input_ids = tokenizer(
            conversations,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        ).input_ids

    targets = input_ids.clone()
    assert conv.sep_style == conversation_lib.SeparatorStyle.MPT

    sep = conv.sep + conv.roles[1]
    for conversation, target in zip(conversations, targets):
        total_len = int(target.ne(tokenizer.pad_token_id).sum())

        rounds = conversation.split(conv.sep)
        re_rounds = [conv.sep.join(rounds[:3])]
        for conv_idx in range(3, len(rounds), 2):
            re_rounds.append(
                conv.sep.join(rounds[conv_idx : conv_idx + 2])
            )
        cur_len = 0
        target[:cur_len] = IGNORE_INDEX
        for i, rou in enumerate(re_rounds):
            if rou == "":
                break

            parts = rou.split(sep)
            if len(parts) != 2:
                break
            parts[0] += sep
            if has_image:
                round_len = len(tokenizer_image_token(rou, tokenizer))
                instruction_len = len(tokenizer_image_token(parts[0], tokenizer)) - 1
            else:
                round_len = len(tokenizer(rou).input_ids)
                instruction_len = len(tokenizer(parts[0]).input_ids) - 1

            if (
                i != 0
                and getattr(tokenizer, "legacy", False)
                and IS_TOKENIZER_GREATER_THAN_0_14
            ):
                round_len += 1
                instruction_len += 1
            target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

            cur_len += round_len
        target[cur_len:] = IGNORE_INDEX

        if cur_len < tokenizer.model_max_length:
            if cur_len != total_len:
                target[:] = IGNORE_INDEX
                print(
                    f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
                    f" (ignored)"
                )

    return dict(
        input_ids=input_ids,
        labels=targets,
    )


def preprocess_plain(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    conversations = []
    for source in sources:
        assert len(source) == 2
        assert DEFAULT_IMAGE_TOKEN in source[0]["value"]
        source[0]["value"] = DEFAULT_IMAGE_TOKEN
        conversation = (
            source[0]["value"]
            + source[1]["value"]
            + conversation_lib.default_conversation.sep
        )
        conversations.append(conversation)

    input_ids = [
        tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
        for prompt in conversations
    ]
    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        tokenized_len = len(tokenizer_image_token(source[0]["value"], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

    return dict(input_ids=input_ids, labels=targets)


def preprocess(
    sources: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
    has_image: bool = False,
) -> Dict:
    """
    Given a list of sources, each is a conversation list. This transform:
    1. Add the signal '### ' at the beginning of each sentence, with end signal '\n';
    2. Concatenate conversations together;
    3. Tokenize the concatenated conversation;
    4. Make a deepcopy as the target. Mask human words with IGNORE_INDEX.
    """
    if (
        conversation_lib.default_conversation.sep_style
        == conversation_lib.SeparatorStyle.PLAIN
    ):
        return preprocess_plain(sources, tokenizer)
    if (
        conversation_lib.default_conversation.sep_style
        == conversation_lib.SeparatorStyle.LLAMA_2
    ):
        return preprocess_llama_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version.startswith("v1"):
        return preprocess_v1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "mpt":
        return preprocess_mpt(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "llama3":
        return preprocess_llama3(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "llama3_1":
        return preprocess_llama_3_1(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "llama3_2":
        return preprocess_llama_3_2(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "phi3":
        return preprocess_phi3(sources, tokenizer, has_image=has_image)
    if conversation_lib.default_conversation.version == "qwen":
        return preprocess_qwen(sources, tokenizer, has_image=has_image)

    conversations = []
    for source in sources:
        header = f"{conversation_lib.default_conversation.system}\n\n"
        conversation = _add_speaker_and_signal(header, source)
        conversations.append(conversation)

    def get_tokenize_len(prompts):
        return [len(tokenizer_image_token(prompt, tokenizer)) for prompt in prompts]

    if has_image:
        input_ids = [
            tokenizer_image_token(prompt, tokenizer, return_tensors="pt")
            for prompt in conversations
        ]
    else:
        conversations_tokenized = _tokenize_fn(conversations, tokenizer)
        input_ids = conversations_tokenized["input_ids"]

    targets = copy.deepcopy(input_ids)
    for target, source in zip(targets, sources):
        if has_image:
            tokenized_lens = get_tokenize_len([header] + [s["value"] for s in source])
        else:
            tokenized_lens = _tokenize_fn(
                [header] + [s["value"] for s in source],
                tokenizer,
            )["input_ids_lens"]

        speakers = [sentence["from"] for sentence in source]
        _mask_targets(target, tokenized_lens, speakers)

    return dict(input_ids=input_ids, labels=targets)
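
# All preprocess_* variants above return the same structure (shapes shown for
# illustration):
#
#   {
#       "input_ids": LongTensor (or list of LongTensor), one row per conversation,
#       "labels":    same shape, with non-assistant tokens set to IGNORE_INDEX,
#   }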


class LazySupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path: str,
        tokenizer: transformers.PreTrainedTokenizer,
        data_args,
    ) -> None:
        super(LazySupervisedDataset, self).__init__()
        list_data_dict = json.load(open(data_path, "r"))

        self.tokenizer = tokenizer
        self.list_data_dict = list_data_dict
        self.data_args = data_args

    @property
    def lengths(self):
        length_list = []
        for sample in self.list_data_dict:
            img_tokens = 128 if "image" in sample else 0
            length_list.append(
                sum(len(conv["value"].split()) for conv in sample["conversations"])
                + img_tokens
            )
        return length_list

    @property
    def modality_lengths(self) -> List[int]:
        length_list = []
        for sample in self.list_data_dict:
            cur_len = sum(
                len(conv["value"].split()) for conv in sample["conversations"]
            )
            cur_len = (
                cur_len if ("image" in sample) or ("video" in sample) else -cur_len
            )
            length_list.append(cur_len)
        return length_list

    def __len__(self) -> int:
        return len(self.list_data_dict)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        sources = self.list_data_dict[i]
        if isinstance(i, int):
            sources = [sources]
        assert len(sources) == 1, "Don't know why it is wrapped to a list"
        has_image = True
        if "image" in sources[0]:
            image_file = self.list_data_dict[i]["image"]
            image_folder = self.data_args.image_folder
            processor = self.data_args.image_processor
            full_path = os.path.join(image_folder, image_file)
            if not os.path.exists(full_path):
                print(full_path)
                has_image = False
                sources = copy.deepcopy([e["conversations"] for e in sources])
            else:
                image = Image.open(full_path).convert("RGB")
                if self.data_args.image_aspect_ratio == "sam":
                    # The SAM branch expects a BGR numpy array.
                    image = np.array(image)[:, :, ::-1]
                if self.data_args.image_aspect_ratio == "pad":

                    def expand2square(pil_img, background_color):
                        width, height = pil_img.size
                        if width == height:
                            return pil_img
                        elif width > height:
                            result = Image.new(
                                pil_img.mode, (width, width), background_color
                            )
                            result.paste(pil_img, (0, (width - height) // 2))
                            return result
                        else:
                            result = Image.new(
                                pil_img.mode, (height, height), background_color
                            )
                            result.paste(pil_img, ((height - width) // 2, 0))
                            return result

                    image = expand2square(
                        image, tuple(int(x * 255) for x in processor.image_mean)
                    )
                    image = processor.preprocess(image, return_tensors="pt")[
                        "pixel_values"
                    ][0]
                else:
                    if self.data_args.image_aspect_ratio != "sam":
                        image = processor.preprocess(image, return_tensors="pt")[
                            "pixel_values"
                        ][0]
                sources = preprocess_multimodal(
                    copy.deepcopy([e["conversations"] for e in sources]), self.data_args
                )
        elif "video" in sources[0]:
            video_file = self.list_data_dict[i]["video"]
            video_folder = self.data_args.image_folder
            if "webvid" in video_folder:
                video_file = os.path.join(video_folder, "videos", video_file)
            elif "ActivityNet" in video_folder:
                video_file = os.path.join(video_folder, "train_val", video_file)
            else:
                video_file = os.path.join(video_folder, video_file)
            if not os.path.exists(video_file):
                print("nonexist: {}".format(video_file), flush=True)
                for sub_folder in os.listdir(video_folder):
                    if os.path.isdir(os.path.join(video_folder, sub_folder)):
                        for sub_sub_folder in os.listdir(
                            os.path.join(video_folder, sub_folder)
                        ):
                            print("folder", sub_folder, sub_sub_folder)
                has_image = False
                sources = copy.deepcopy([e["conversations"] for e in sources])
            else:
                if video_file.endswith(".webm"):
                    has_image = False
                    sources = copy.deepcopy([e["conversations"] for e in sources])
                else:
                    try:
                        vr = VideoReader(video_file, ctx=cpu(0), num_threads=1)
                        sample_fps = round(vr.get_avg_fps() / self.data_args.video_fps)
                        frame_idx = [i for i in range(0, len(vr), sample_fps)]
                        video = vr.get_batch(frame_idx).asnumpy()
                        if self.data_args.image_aspect_ratio == "sam":
                            # Keep at most 100 frames of BGR data for the SAM branch.
                            image = video[:, :, :, ::-1][:100]
                        else:
                            processor = self.data_args.image_processor
                            image = processor.preprocess(video, return_tensors="pt")[
                                "pixel_values"
                            ]
                        sources = preprocess_multimodal(
                            copy.deepcopy([e["conversations"] for e in sources]),
                            self.data_args,
                        )
                    except Exception:
                        # Fall back to text-only if decoding or preprocessing fails.
                        has_image = False
                        sources = copy.deepcopy([e["conversations"] for e in sources])
        else:
            has_image = False
            sources = copy.deepcopy([e["conversations"] for e in sources])
        data_dict = preprocess(
            sources,
            self.tokenizer,
            has_image=has_image,
        )
        if isinstance(i, int):
            data_dict = dict(
                input_ids=data_dict["input_ids"][0], labels=data_dict["labels"][0]
            )

        if has_image:
            if "image" in self.list_data_dict[i]:
                data_dict["image"] = image
            elif "video" in self.list_data_dict[i]:
                data_dict["image"] = image
        elif self.data_args.is_multimodal:
            # Text-only sample in a multimodal run: feed a dummy image/video so
            # batch shapes stay consistent.
            if self.data_args.image_aspect_ratio == "sam":
                if "video" in self.list_data_dict[i]:
                    data_dict["image"] = np.zeros((1, 1024, 1024, 3)).astype(np.uint8)
                else:
                    data_dict["image"] = np.zeros((1024, 1024, 3)).astype(np.uint8)
            else:
                crop_size = self.data_args.image_processor.crop_size
                if "video" in self.list_data_dict[i]:
                    data_dict["image"] = torch.zeros(
                        1, 3, crop_size["height"], crop_size["width"]
                    )
                else:
                    data_dict["image"] = torch.zeros(
                        3, crop_size["height"], crop_size["width"]
                    )

        if has_image:
            if self.data_args.num_points > 0:
                if "box" in self.list_data_dict[i]:
                    x1, y1, x2, y2 = self.list_data_dict[i]["box"]
                    points = []
                    x = random.uniform(x1, x2)
                    y = random.uniform(y1, y2)
                    points.append(torch.tensor([x, y, 1]))
                    for _ in range(1, self.data_args.num_points):
                        points.append(torch.tensor([0, 0, 0]))
                    points = torch.stack(points, dim=0)
                    data_dict["point"] = points
                else:
                    if "point" in self.list_data_dict[i]:
                        points = torch.tensor(self.list_data_dict[i]["point"])
                        data_dict["point"] = points
                    else:
                        # No annotation: fall back to a regular grid of points.
                        # Separate loop variables avoid shadowing the dataset index `i`.
                        points = []
                        grid = int(np.sqrt(self.data_args.num_points))
                        height, width = image.shape[0], image.shape[1]
                        for grid_x in range(grid):
                            for grid_y in range(grid):
                                points.append(
                                    torch.tensor(
                                        [
                                            width / grid / 2.0 + grid_x / grid * width,
                                            height / grid / 2.0 + grid_y / grid * height,
                                            1,
                                        ]
                                    )
                                )
                        points = torch.stack(points, dim=0)
                        data_dict["point"] = points
        elif self.data_args.is_multimodal:
            if self.data_args.num_points > 0:
                points = []
                grid = int(np.sqrt(self.data_args.num_points))
                height, width = data_dict["image"].shape[0], data_dict["image"].shape[1]
                for grid_x in range(grid):
                    for grid_y in range(grid):
                        points.append(
                            torch.tensor(
                                [
                                    width / grid / 2.0 + grid_x / grid * width,
                                    height / grid / 2.0 + grid_y / grid * height,
                                    1,
                                ]
                            )
                        )
                points = torch.stack(points, dim=0)
                data_dict["point"] = points

        return data_dict


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple(
            [instance[key] for instance in instances] for key in ("input_ids", "labels")
        )
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id,
        )
        labels = torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX
        )
        input_ids = input_ids[:, : self.tokenizer.model_max_length]
        labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        if "image" in instances[0]:
            images = [instance["image"] for instance in instances]
            batch["images"] = images

        if "point" in instances[0]:
            points = [instance["point"] for instance in instances]
            batch["points"] = torch.stack(points)

        return batch


def make_supervised_data_module(
    tokenizer: transformers.PreTrainedTokenizer,
    data_args,
) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    train_dataset = LazySupervisedDataset(
        tokenizer=tokenizer, data_path=data_args.data_path, data_args=data_args
    )
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(
        train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator
    )
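

# Illustrative wiring (a sketch; `model` and `training_args` are assumed to be
# constructed elsewhere in the training entry point, not in this file):
#
#   data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
#   trainer = transformers.Trainer(
#       model=model,
#       tokenizer=tokenizer,
#       args=training_args,
#       **data_module,
#   )
#   trainer.train()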