|
|
|
|
|
|
|
from safetensors.torch import load_file
|
|
|
|
|
from collections import defaultdict |
|
import copy |
|
import json |
|
import os |
|
from os.path import exists, join, isdir |
|
from dataclasses import dataclass, field |
|
import sys |
|
from typing import Optional, Dict, Sequence |
|
import numpy as np |
|
from tqdm import tqdm |
|
import logging |
|
|
|
import torch |
|
import transformers |
|
from torch.nn.utils.rnn import pad_sequence |
|
import argparse |
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForCausalLM, |
|
LineByLineTextDataset, |
|
set_seed, |
|
Seq2SeqTrainer, |
|
Trainer, |
|
LlamaTokenizerFast |
|
) |
|
|
|
from trl import SFTTrainer |
|
from datasets import load_dataset |
|
import evaluate |
|
|
|
from peft import ( |
|
LoraConfig, |
|
get_peft_model_state_dict, |
|
set_peft_model_state_dict, |
|
PeftModel |
|
) |
|
from peft.tuners.lora import LoraLayer |
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR |
|
from auto_gptq.utils.peft_utils import get_gptq_peft_model, GPTQLoraConfig |
|
from auto_gptq import AutoGPTQForCausalLM |
|
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear |
|
|
|
torch.backends.cuda.matmul.allow_tf32 = True |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
IGNORE_INDEX = -100 |
|
DEFAULT_PAD_TOKEN = "[PAD]" |
|
|
|
|
os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
|
def prepare_model_for_int8_training(model, use_gradient_checkpointing=True): |
|
r""" |
|
This method wraps the entire protocol for preparing a model before running a training. This includes: |
|
1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm |
|
head to fp32 |
|
|
|
Args: |
|
model, (`transformers.PreTrainedModel`): |
|
The loaded model from `transformers` |
|
""" |
|
for name, param in model.named_parameters(): |
|
|
|
param.requires_grad = False |
|
|
|
if use_gradient_checkpointing: |
|
|
|
if hasattr(model, "enable_input_require_grads"): |
|
model.enable_input_require_grads() |
|
else: |
|
|
|
def make_inputs_require_grad(module, input, output): |
|
output.requires_grad_(True) |
|
|
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
|
|
model.gradient_checkpointing_enable() |
|
|
|
return model |
|
|
|
@dataclass |
|
class ModelArguments: |
|
model_path: Optional[str] = field( |
|
default="./src/" |
|
) |
|
src_lora_path: Optional[str] = field( |
|
default=None, |
|
) |
|
trust_remote_code: Optional[bool] = field( |
|
default=False, |
|
metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."} |
|
) |
|
|
|
@dataclass |
|
class DataArguments: |
|
eval_dataset_size: int = field( |
|
default=1024, metadata={"help": "Size of validation dataset."} |
|
) |
|
max_train_samples: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "For debugging purposes or quicker training, truncate the number of training examples to this " |
|
"value if set." |
|
}, |
|
) |
|
offload_folder: Optional[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Offload folder " |
|
"value if set." |
|
}, |
|
) |
|
max_eval_samples: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this " |
|
"value if set." |
|
}, |
|
) |
|
source_max_len: int = field( |
|
default=1024, |
|
metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."}, |
|
) |
|
target_max_len: int = field( |
|
default=1024, |
|
metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."}, |
|
) |
|
dataset: str = field( |
|
default='alpaca', |
|
metadata={"help": "Which dataset to finetune on. See datamodule for options."} |
|
) |
|
|
|
@dataclass |
|
class TrainingArguments(transformers.Seq2SeqTrainingArguments): |
|
cache_dir: Optional[str] = field( |
|
default=None |
|
) |
|
train_on_source: Optional[bool] = field( |
|
default=False, |
|
metadata={"help": "Whether to train on the input in addition to the target text."} |
|
) |
|
mmlu_split: Optional[str] = field( |
|
default='eval', |
|
metadata={"help": "The MMLU split to run on"} |
|
) |
|
mmlu_dataset: Optional[str] = field( |
|
default='mmlu-fs', |
|
metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few shot."} |
|
) |
|
do_mmlu_eval: Optional[bool] = field( |
|
default=False, |
|
metadata={"help": "Whether to run the MMLU evaluation."} |
|
) |
|
max_mmlu_samples: Optional[int] = field( |
|
default=None, |
|
metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMMLU dataset."} |
|
) |
|
mmlu_source_max_len: int = field( |
|
default=2048, |
|
metadata={"help": "Maximum source sequence length for mmlu."} |
|
) |
|
full_finetune: bool = field( |
|
default=False, |
|
metadata={"help": "Finetune the entire model without adapters."} |
|
) |
|
adam8bit: bool = field( |
|
default=False, |
|
metadata={"help": "Use 8-bit adam."} |
|
) |
|
lora_r: int = field( |
|
default=64, |
|
metadata={"help": "Lora R dimension."} |
|
) |
|
lora_alpha: float = field( |
|
default=16, |
|
metadata={"help": " Lora alpha."} |
|
) |
|
lora_dropout: float = field( |
|
default=0.0, |
|
metadata={"help":"Lora dropout."} |
|
) |
|
max_memory_MB: int = field( |
|
default=24000, |
|
metadata={"help": "Free memory per gpu."} |
|
) |
|
report_to: str = field( |
|
default='none', |
|
metadata={"help": "To use wandb or something else for reporting."} |
|
) |
|
output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints'}) |
|
optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used'}) |
|
per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'}) |
|
    gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many gradients to accumulate before performing an optimizer step'})
|
max_steps: int = field(default=0, metadata={"help": 'How many optimizer update steps to take'}) |
|
weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW'}) |
|
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate'})
|
    remove_unused_columns: bool = field(default=False, metadata={"help": 'Remove unused columns. Needed to make this codebase work.'})
|
max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'}) |
|
gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'}) |
|
do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'}) |
|
    lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant is a bit better than cosine and has an advantage for analysis'})
|
warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for'}) |
|
logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss'}) |
|
group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with same length. Saves memory and speeds up training considerably.'}) |
|
save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints'}) |
|
save_steps: int = field(default=250, metadata={"help": 'How often to save a model'}) |
|
save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to save before the oldest is overwritten'}) |
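
# An illustrative invocation of this script (the script filename and the paths
# below are placeholders, not defined by this file):
#   python finetune_gptq_lora.py \
#       --model_path ./src/ \
#       --dataset dataset \
#       --output_dir ./output \
#       --per_device_train_batch_size 1 \
#       --gradient_accumulation_steps 16 \
#       --max_steps 1000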
|
|
|
def find_all_linear_names(args, model): |
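    # Collect the names of all quantized linear layers (or plain nn.Linear layers when
    # fully finetuning) to use as LoRA target_modules; lm_head is excluded below.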
|
    cls = GeneralQuantLinear if not args.full_finetune else torch.nn.Linear
|
lora_module_names = set() |
|
for name, module in model.named_modules(): |
|
if isinstance(module, cls): |
|
names = name.split('.') |
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) |
|
|
|
|
|
if 'lm_head' in lora_module_names: |
|
lora_module_names.remove('lm_head') |
|
return list(lora_module_names) |
|
|
|
|
|
class SavePeftModelCallback(transformers.TrainerCallback): |
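    # At each save, writes only the LoRA adapter weights into an adapter_model folder and
    # deletes any full pytorch_model.bin in the checkpoint, keeping checkpoints small.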
|
def save_model(self, args, state, kwargs): |
|
print('Saving PEFT checkpoint...') |
|
if state.best_model_checkpoint is not None: |
|
checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model") |
|
else: |
|
checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}") |
|
|
|
peft_model_path = os.path.join(checkpoint_folder, "adapter_model") |
|
kwargs["model"].save_pretrained(peft_model_path) |
|
|
|
pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin") |
|
if os.path.exists(pytorch_model_path): |
|
os.remove(pytorch_model_path) |
|
|
|
def on_save(self, args, state, control, **kwargs): |
|
self.save_model(args, state, kwargs) |
|
return control |
|
|
|
def on_train_end(self, args, state, control, **kwargs): |
|
def touch(fname, times=None): |
|
with open(fname, 'a'): |
|
os.utime(fname, times) |
|
|
|
touch(join(args.output_dir, 'completed')) |
|
self.save_model(args, state, kwargs) |
|
|
|
def get_accelerate_model(args, checkpoint_dir): |
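    # Load the GPTQ-quantized base model sharded across the available GPUs, freeze it,
    # and attach (or resume from a checkpoint) trainable LoRA adapters.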
|
|
|
n_gpus = torch.cuda.device_count() |
|
max_memory = f'{args.max_memory_MB}MB' |
|
max_memory = {i: max_memory for i in range(n_gpus)} |
|
|
|
if args.full_finetune: assert args.bits in [16, 32] |
|
|
|
print(f'loading base model {args.model_path}...') |
|
model = AutoGPTQForCausalLM.from_quantized( |
|
args.model_path, |
|
low_cpu_mem_usage=True, |
|
device_map='auto', |
|
max_memory=max_memory, |
|
trust_remote_code=args.trust_remote_code, |
|
        inject_fused_attention=True,
        inject_fused_mlp=False,
|
use_triton=False, |
|
warmup_triton=False, |
|
offload_folder='offload', |
|
trainable=True |
|
) |
|
    # Expose the quantize_config on the wrapped transformers model so the peft/auto_gptq
    # helpers below can find it.
    model.model.quantize_config = model.quantize_config
|
model.train() |
|
|
|
setattr(model, 'model_parallel', True) |
|
setattr(model, 'is_parallelizable', True) |
|
modules = find_all_linear_names(args, model) |
|
|
|
print("Modules: ", modules) |
|
|
|
    model.config.torch_dtype = torch.float16
|
|
|
if not args.full_finetune: |
|
model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing) |
|
if args.gradient_checkpointing: |
|
model.gradient_checkpointing_enable() |
|
|
|
config = GPTQLoraConfig( |
|
r=args.lora_r, |
|
lora_alpha=args.lora_alpha, |
|
target_modules=modules, |
|
lora_dropout=args.lora_dropout, |
|
bias="none", |
|
task_type="CAUSAL_LM", |
|
) |
|
if not args.full_finetune: |
|
if checkpoint_dir is not None: |
|
print("Loading adapters from checkpoint.") |
|
model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model')) |
|
for name, p in model.named_parameters(): |
|
if 'lora' in name: |
|
print(name, p.sum()) |
|
else: |
|
print(f'adding LoRA modules...') |
|
model = get_gptq_peft_model(model, config, auto_find_all_linears=True, train_mode=True) |
|
|
|
if args.gradient_checkpointing: |
|
if hasattr(model, "enable_input_require_grads"): |
|
model.enable_input_require_grads() |
|
else: |
|
def make_inputs_require_grad(module, input, output): |
|
output.requires_grad_(True) |
|
model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
|
|
|
|
|
for name, module in model.named_modules(): |
|
if isinstance(module, LoraLayer): |
|
if args.bf16: |
|
module = module.to(torch.bfloat16) |
|
if 'norm' in name: |
|
module = module.to(torch.float32) |
|
if 'lm_head' in name or 'embed_tokens' in name: |
|
if hasattr(module, 'weight'): |
|
if args.bf16 and module.weight.dtype == torch.float32: |
|
module = module.to(torch.bfloat16) |
|
return model |
|
|
|
def print_trainable_parameters(args, model): |
|
""" |
|
Prints the number of trainable parameters in the model. |
|
""" |
|
trainable_params = 0 |
|
all_param = 0 |
|
for _, param in model.named_parameters(): |
|
all_param += param.numel() |
|
if param.requires_grad: |
|
trainable_params += param.numel() |
|
    try:
        trainable_params /= (32 // model.quantize_config.bits)
    except AttributeError:
        pass
|
print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param}") |
|
|
|
def smart_tokenizer_and_embedding_resize( |
|
special_tokens_dict: Dict, |
|
tokenizer: transformers.PreTrainedTokenizer, |
|
model: transformers.PreTrainedModel, |
|
): |
|
"""Resize tokenizer and embedding. |
|
|
|
Note: This is the unoptimized version that may make your embedding size not be divisible by 64. |
|
""" |
|
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) |
|
model.resize_token_embeddings(len(tokenizer)) |
|
|
|
if num_new_tokens > 0: |
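        # Initialize the newly added embedding rows to the mean of the existing embeddings
        # so the new tokens start from a reasonable point instead of random values.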
|
input_embeddings = model.get_input_embeddings().weight.data |
|
output_embeddings = model.get_output_embeddings().weight.data |
|
|
|
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) |
|
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) |
|
|
|
input_embeddings[-num_new_tokens:] = input_embeddings_avg |
|
output_embeddings[-num_new_tokens:] = output_embeddings_avg |
|
|
|
@dataclass |
|
class DataCollatorForCausalLM(object): |
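    # Tokenizes sources and targets separately, concatenates them into one sequence, and
    # (unless train_on_source is set) masks the source tokens with IGNORE_INDEX so the
    # loss is only computed on the target tokens.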
|
tokenizer: transformers.PreTrainedTokenizer |
|
source_max_len: int |
|
target_max_len: int |
|
train_on_source: bool |
|
predict_with_generate: bool |
|
|
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
|
|
sources = [example['input'] for example in instances] |
|
targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances] |
|
|
|
tokenized_sources_with_prompt = self.tokenizer( |
|
sources, |
|
max_length=self.source_max_len, |
|
truncation=True, |
|
) |
|
tokenized_targets = self.tokenizer( |
|
targets, |
|
max_length=self.target_max_len, |
|
truncation=True, |
|
add_special_tokens=False, |
|
) |
|
|
|
input_ids = [] |
|
labels = [] |
|
for tokenized_source, tokenized_target in zip( |
|
tokenized_sources_with_prompt['input_ids'], |
|
tokenized_targets['input_ids'] |
|
): |
|
if not self.predict_with_generate: |
|
input_ids.append(torch.tensor(tokenized_source + tokenized_target)) |
|
if not self.train_on_source: |
|
labels.append( |
|
torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target)) |
|
) |
|
else: |
|
labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target))) |
|
else: |
|
input_ids.append(torch.tensor(tokenized_source)) |
|
|
|
input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) |
|
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None |
|
data_dict = { |
|
'input_ids': input_ids, |
|
'attention_mask':input_ids.ne(self.tokenizer.pad_token_id), |
|
} |
|
if labels is not None: |
|
data_dict['labels'] = labels |
|
return data_dict |
|
|
|
def extract_unnatural_instructions_data(examples, extract_reformulations=False): |
|
out = { |
|
'input': [], |
|
'output': [], |
|
} |
|
for example_instances in examples['instances']: |
|
for instance in example_instances: |
|
out['input'].append(instance['instruction_with_input']) |
|
out['output'].append(instance['output']) |
|
if extract_reformulations: |
|
for example_reformulations in examples['reformulations']: |
|
if example_reformulations is not None: |
|
for instance in example_reformulations: |
|
out['input'].append(instance['instruction_with_input']) |
|
out['output'].append(instance['output']) |
|
return out |
|
|
|
def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict: |
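    # Two dataset modes are supported: 'txt' groups the lines of txt.txt into roughly
    # 512-character chunks and uses each chunk as a target with an empty input, while
    # 'dataset' loads input/output pairs from dataset.json.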
|
|
|
print(args.dataset) |
|
|
|
if args.dataset == 'txt': |
|
|
with open("txt.txt","r",encoding="utf-8") as f: |
|
data = f.readlines() |
|
|
|
tmp = '' |
|
gdata = [] |
|
current_length = 0 |
|
print("Creating groups...") |
|
        for s in data:
            if current_length + len(s) <= 512:
                tmp = tmp + s + "\n"
                current_length += len(s)
            else:
                gdata.append(tmp)
                tmp = s
                current_length = len(s)
        # Append the final partial group so the tail of the file is not dropped.
        if tmp:
            gdata.append(tmp)
|
|
|
        records = list(map(lambda x: {
            'input': '',
            'output': x
        }, gdata))
        from datasets import Dataset
        dataset = Dataset.from_list(records)
|
|
|
elif args.dataset == 'dataset': |
|
dataset = load_dataset("json", data_files='dataset.json') |
|
|
|
if args.do_train: |
|
if args.dataset == 'txt': |
|
train_dataset = dataset |
|
else: |
|
train_dataset = dataset['train'] |
|
if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples: |
|
train_dataset = train_dataset.select(range(args.max_train_samples)) |
|
if args.group_by_length: |
|
train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])}) |
|
|
|
    # No evaluation split is built for these datasets; define it explicitly so the
    # return below does not hit a NameError when --do_eval or --do_predict is set.
    eval_dataset = None

    data_collator = DataCollatorForCausalLM(
|
tokenizer=tokenizer, |
|
source_max_len=args.source_max_len, |
|
target_max_len=args.target_max_len, |
|
train_on_source=args.train_on_source, |
|
predict_with_generate=args.predict_with_generate, |
|
) |
|
return dict( |
|
train_dataset=train_dataset if args.do_train else None, |
|
eval_dataset=eval_dataset if args.do_eval else None, |
|
predict_dataset=eval_dataset if args.do_predict else None, |
|
data_collator=data_collator |
|
) |
|
|
|
def get_last_checkpoint(checkpoint_dir): |
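    # Returns (checkpoint_dir, is_completed): the newest checkpoint-<step> folder if one
    # exists, and whether a 'completed' marker shows training already finished.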
|
if isdir(checkpoint_dir): |
|
is_completed = exists(join(checkpoint_dir, 'completed')) |
|
if is_completed: return None, True |
|
max_step = 0 |
|
for filename in os.listdir(checkpoint_dir): |
|
if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'): |
|
max_step = max(max_step, int(filename.replace('checkpoint-', ''))) |
|
print("MX: ", max_step, " - ", filename) |
|
if max_step == 0: return None, is_completed |
|
checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}') |
|
print(f"Found a previous checkpoint at: {checkpoint_dir}") |
|
return checkpoint_dir, is_completed |
|
return None, False |
|
|
|
def train(): |
|
hfparser = transformers.HfArgumentParser(( |
|
ModelArguments, DataArguments, TrainingArguments |
|
)) |
|
model_args, data_args, training_args, extra_args = \ |
|
hfparser.parse_args_into_dataclasses(return_remaining_strings=True) |
|
|
|
args = argparse.Namespace( |
|
**vars(model_args), **vars(data_args), **vars(training_args) |
|
) |
|
|
|
checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir) |
|
|
|
if completed_training: |
|
print('Detected that training was already completed!') |
|
|
|
model = get_accelerate_model(args, checkpoint_dir) |
|
    training_args.skip_loading_checkpoint_weights = True
|
|
|
load_existing_lora = os.path.exists('src_lora/adapter_model.safetensors') |
|
|
|
if load_existing_lora: |
|
print(f"Loading existing LoRA") |
|
adapters_weights = load_file('src_lora/adapter_model.safetensors') |
|
set_peft_model_state_dict(model, adapters_weights) |
|
|
|
model.config.use_cache = False |
|
print_trainable_parameters(args, model) |
|
print('loaded model') |
|
set_seed(args.seed) |
|
|
|
|
|
tokenizer = AutoTokenizer.from_pretrained( |
|
args.model_path, |
|
cache_dir=args.cache_dir, |
|
padding_side="right", |
|
use_fast=True, |
|
) |
|
|
|
if tokenizer.pad_token is None: |
|
smart_tokenizer_and_embedding_resize( |
|
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), |
|
tokenizer=tokenizer, |
|
model=model, |
|
) |
|
|
|
if isinstance(tokenizer, LlamaTokenizerFast): |
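        # The LLaMA tokenizer may not have its special tokens set correctly; add them
        # explicitly so they are not split into other tokens. For LLaMA configs,
        # model.config.pad_token_id is typically 0, which maps to the <unk> token.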
|
|
|
|
|
|
|
|
|
tokenizer.add_special_tokens( |
|
{ |
|
"eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id), |
|
"bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id), |
|
"unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id), |
|
} |
|
) |
|
|
|
data_module = make_data_module(tokenizer=tokenizer, args=args) |
|
trainer = Seq2SeqTrainer( |
|
|
|
model=model, |
|
tokenizer=tokenizer, |
|
args=training_args, |
|
**{k:v for k,v in data_module.items() if k != 'predict_dataset'}, |
|
) |
|
|
|
|
|
if not args.full_finetune: |
|
trainer.add_callback(SavePeftModelCallback) |
|
|
|
|
|
dtypes = {} |
|
for _, p in model.named_parameters(): |
|
dtype = p.dtype |
|
if dtype not in dtypes: dtypes[dtype] = 0 |
|
dtypes[dtype] += p.numel() |
|
total = 0 |
|
for k, v in dtypes.items(): total+= v |
|
for k, v in dtypes.items(): |
|
print(k, v, v/total) |
|
|
|
all_metrics = {"run_name": args.run_name} |
|
|
|
if args.do_train: |
|
train_result = trainer.train(resume_from_checkpoint=False) |
|
metrics = train_result.metrics |
|
trainer.log_metrics("train", metrics) |
|
trainer.save_metrics("train", metrics) |
|
trainer.save_state() |
|
all_metrics.update(metrics) |
|
|
|
    if args.do_train:
|
with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout: |
|
fout.write(json.dumps(all_metrics)) |
|
|
|
if __name__ == "__main__": |
|
train() |
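
# A minimal sketch of loading the trained adapter for inference afterwards
# (the checkpoint path is illustrative):
#   model = AutoGPTQForCausalLM.from_quantized(model_path, device_map='auto')
#   model = PeftModel.from_pretrained(model, join(output_dir, 'checkpoint-<step>', 'adapter_model'))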
|
|