zlsl committed
Commit e714177
1 parent: 9145ae9

Upload _gptqlora.py

Files changed (1)
  1. _gptqlora.py +613 -0
_gptqlora.py ADDED
@@ -0,0 +1,613 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from safetensors import safe_open
from safetensors.torch import load_model, save_model, load_file


from collections import defaultdict
import copy
import json
import os
from os.path import exists, join, isdir
from dataclasses import dataclass, field
import sys
from typing import Optional, Dict, Sequence
import numpy as np
from tqdm import tqdm
import logging

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LineByLineTextDataset,
    set_seed,
    Seq2SeqTrainer,
    Trainer,
    LlamaTokenizerFast
)

from trl import SFTTrainer
from datasets import load_dataset
import evaluate

from peft import (
    LoraConfig,
    get_peft_model_state_dict,
    set_peft_model_state_dict,
    PeftModel
)
from peft.tuners.lora import LoraLayer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from auto_gptq.utils.peft_utils import get_gptq_peft_model, GPTQLoraConfig
from auto_gptq import AutoGPTQForCausalLM
from auto_gptq.nn_modules.qlinear import GeneralQuantLinear

torch.backends.cuda.matmul.allow_tf32 = True

logger = logging.getLogger(__name__)

IGNORE_INDEX = -100
DEFAULT_PAD_TOKEN = "[PAD]"

os.environ["TOKENIZERS_PARALLELISM"] = "false"

def prepare_model_for_int8_training(model, use_gradient_checkpointing=True):
    r"""
    Wrap the steps needed to prepare the loaded model for adapter training:
        1. freeze all base-model parameters,
        2. make the embedding outputs require gradients (needed for gradient
           checkpointing when the base weights are frozen),
        3. optionally enable gradient checkpointing for memory efficiency.

    Args:
        model (`transformers.PreTrainedModel`):
            The loaded model from `transformers`.
    """
    for name, param in model.named_parameters():
        # freeze base model's layers
        param.requires_grad = False

    if use_gradient_checkpointing:
        # For backward compatibility
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:

            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        # enable gradient checkpointing for memory efficiency
        model.gradient_checkpointing_enable()

    return model

@dataclass
class ModelArguments:
    model_path: Optional[str] = field(
        default="./src/"
    )
    src_lora_path: Optional[str] = field(
        default=None,
    )
    trust_remote_code: Optional[bool] = field(
        default=False,
        metadata={"help": "Enable unpickling of arbitrary code in AutoModelForCausalLM#from_pretrained."}
    )

@dataclass
class DataArguments:
    eval_dataset_size: int = field(
        default=1024, metadata={"help": "Size of validation dataset."}
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
                    "value if set."
        },
    )
    offload_folder: Optional[str] = field(
        default=None,
        metadata={
            "help": "Offload folder to use, if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
                    "value if set."
        },
    )
    source_max_len: int = field(
        default=1024,
        metadata={"help": "Maximum source sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    target_max_len: int = field(
        default=1024,
        metadata={"help": "Maximum target sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    dataset: str = field(
        default='alpaca',
        metadata={"help": "Which dataset to finetune on. See datamodule for options."}
    )

@dataclass
class TrainingArguments(transformers.Seq2SeqTrainingArguments):
    cache_dir: Optional[str] = field(
        default=None
    )
    train_on_source: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to train on the input in addition to the target text."}
    )
    mmlu_split: Optional[str] = field(
        default='eval',
        metadata={"help": "The MMLU split to run on."}
    )
    mmlu_dataset: Optional[str] = field(
        default='mmlu-fs',
        metadata={"help": "MMLU dataset to use: options are `mmlu-zs` for zero-shot or `mmlu-fs` for few-shot."}
    )
    do_mmlu_eval: Optional[bool] = field(
        default=False,
        metadata={"help": "Whether to run the MMLU evaluation."}
    )
    max_mmlu_samples: Optional[int] = field(
        default=None,
        metadata={"help": "If set, only evaluates on `max_mmlu_samples` of the MMLU dataset."}
    )
    mmlu_source_max_len: int = field(
        default=2048,
        metadata={"help": "Maximum source sequence length for MMLU."}
    )
    full_finetune: bool = field(
        default=False,
        metadata={"help": "Finetune the entire model without adapters."}
    )
    adam8bit: bool = field(
        default=False,
        metadata={"help": "Use 8-bit Adam."}
    )
    lora_r: int = field(
        default=64,
        metadata={"help": "LoRA R dimension."}
    )
    lora_alpha: float = field(
        default=16,
        metadata={"help": "LoRA alpha."}
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={"help": "LoRA dropout."}
    )
    max_memory_MB: int = field(
        default=24000,
        metadata={"help": "Free memory per GPU, in MB."}
    )
    report_to: str = field(
        default='none',
        metadata={"help": "Where to report metrics (e.g. wandb), or 'none'."}
    )
    output_dir: str = field(default='./output', metadata={"help": 'The output dir for logs and checkpoints.'})
    optim: str = field(default='paged_adamw_32bit', metadata={"help": 'The optimizer to be used.'})
    per_device_train_batch_size: int = field(default=1, metadata={"help": 'The training batch size per GPU. Increase for better speed.'})
    gradient_accumulation_steps: int = field(default=16, metadata={"help": 'How many steps to accumulate gradients over before performing an optimizer step.'})
    max_steps: int = field(default=0, metadata={"help": 'How many optimizer update steps to take.'})
    weight_decay: float = field(default=0.0, metadata={"help": 'The L2 weight decay rate of AdamW.'})  # use lora dropout instead for regularization if needed
    learning_rate: float = field(default=0.0002, metadata={"help": 'The learning rate.'})
    remove_unused_columns: bool = field(default=False, metadata={"help": 'Remove unused columns. Needed to make this codebase work.'})
    max_grad_norm: float = field(default=0.3, metadata={"help": 'Gradient clipping max norm. This is tuned and works well for all models tested.'})
    gradient_checkpointing: bool = field(default=True, metadata={"help": 'Use gradient checkpointing. You want to use this.'})
    do_train: bool = field(default=True, metadata={"help": 'To train or not to train, that is the question?'})
    lr_scheduler_type: str = field(default='constant', metadata={"help": 'Learning rate schedule. Constant is a bit better than cosine and has an advantage for analysis.'})
    warmup_ratio: float = field(default=0.03, metadata={"help": 'Fraction of steps to do a warmup for.'})
    logging_steps: int = field(default=10, metadata={"help": 'The frequency of update steps after which to log the loss.'})
    group_by_length: bool = field(default=True, metadata={"help": 'Group sequences into batches with the same length. Saves memory and speeds up training considerably.'})
    save_strategy: str = field(default='steps', metadata={"help": 'When to save checkpoints.'})
    save_steps: int = field(default=250, metadata={"help": 'How often to save a model.'})
    save_total_limit: int = field(default=40, metadata={"help": 'How many checkpoints to keep before the oldest is overwritten.'})

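# Example invocation (a sketch only; the flags are just the dataclass fields defined
# above plus standard HF TrainingArguments, and all paths are hypothetical):
#
#   python _gptqlora.py --model_path ./src/ --dataset dataset \
#       --output_dir ./output --per_device_train_batch_size 1 \
#       --gradient_accumulation_steps 16 --max_steps 1000 --learning_rate 0.0002
#
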
def find_all_linear_names(args, model):
    cls = GeneralQuantLinear if not args.full_finetune else torch.nn.Linear
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

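# For a GPTQ-quantized LLaMA-style checkpoint this typically returns the attention and
# MLP projection names, e.g. ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj',
# 'up_proj', 'down_proj']; the exact set depends on the architecture being loaded.
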

class SavePeftModelCallback(transformers.TrainerCallback):
    def save_model(self, args, state, kwargs):
        print('Saving PEFT checkpoint...')
        if state.best_model_checkpoint is not None:
            checkpoint_folder = os.path.join(state.best_model_checkpoint, "adapter_model")
        else:
            checkpoint_folder = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        if os.path.exists(pytorch_model_path):
            os.remove(pytorch_model_path)

    def on_save(self, args, state, control, **kwargs):
        self.save_model(args, state, kwargs)
        return control

    def on_train_end(self, args, state, control, **kwargs):
        def touch(fname, times=None):
            with open(fname, 'a'):
                os.utime(fname, times)

        touch(join(args.output_dir, 'completed'))
        self.save_model(args, state, kwargs)

def get_accelerate_model(args, checkpoint_dir):

    n_gpus = torch.cuda.device_count()
    max_memory = f'{args.max_memory_MB}MB'
    max_memory = {i: max_memory for i in range(n_gpus)}

    if args.full_finetune: assert args.bits in [16, 32]

    print(f'loading base model {args.model_path}...')
    model = AutoGPTQForCausalLM.from_quantized(
        args.model_path,
        low_cpu_mem_usage=True,
        device_map='auto',
        max_memory=max_memory,
        trust_remote_code=args.trust_remote_code,
        inject_fused_attention=True,
        inject_fused_mlp=False,
        use_triton=False,
        warmup_triton=False,
        offload_folder='offload',
        trainable=True
    )
    model.model.quantize_config = model.quantize_config
    model.train()

    setattr(model, 'model_parallel', True)
    setattr(model, 'is_parallelizable', True)
    modules = find_all_linear_names(args, model)

    print("Modules: ", modules)

    model.config.torch_dtype = torch.float16  # if args.fp16 else (torch.bfloat16 if args.bf16 else torch.float32)

    if not args.full_finetune:
        model = prepare_model_for_int8_training(model, use_gradient_checkpointing=args.gradient_checkpointing)
    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    config = GPTQLoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=modules,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
    )
    if not args.full_finetune:
        if checkpoint_dir is not None:
            print("Loading adapters from checkpoint.")
            model = PeftModel.from_pretrained(model, join(checkpoint_dir, 'adapter_model'))
            for name, p in model.named_parameters():
                if 'lora' in name:
                    print(name, p.sum())
        else:
            print('adding LoRA modules...')
            model = get_gptq_peft_model(model, config, auto_find_all_linears=True, train_mode=True)

    if args.gradient_checkpointing:
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
    return model

def print_trainable_parameters(args, model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    try:
        trainable_params /= (32 // model.quantize_config.bits)
    except Exception:
        # quantize_config may be missing (e.g. full finetune); keep the raw count
        pass
    print(f"trainable params: {trainable_params} || all params: {all_param} || trainable: {100 * trainable_params / all_param}")

def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Resize tokenizer and embedding.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))

    if num_new_tokens > 0:
        input_embeddings = model.get_input_embeddings().weight.data
        output_embeddings = model.get_output_embeddings().weight.data

        input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)

        input_embeddings[-num_new_tokens:] = input_embeddings_avg
        output_embeddings[-num_new_tokens:] = output_embeddings_avg

@dataclass
class DataCollatorForCausalLM(object):
    tokenizer: transformers.PreTrainedTokenizer
    source_max_len: int
    target_max_len: int
    train_on_source: bool
    predict_with_generate: bool

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        # Extract elements
        sources = [example['input'] for example in instances]
        targets = [f"{example['output']}{self.tokenizer.eos_token}" for example in instances]
        # Tokenize
        tokenized_sources_with_prompt = self.tokenizer(
            sources,
            max_length=self.source_max_len,
            truncation=True,
        )
        tokenized_targets = self.tokenizer(
            targets,
            max_length=self.target_max_len,
            truncation=True,
            add_special_tokens=False,
        )
        # Build the input and labels for causal LM
        input_ids = []
        labels = []
        for tokenized_source, tokenized_target in zip(
            tokenized_sources_with_prompt['input_ids'],
            tokenized_targets['input_ids']
        ):
            if not self.predict_with_generate:
                input_ids.append(torch.tensor(tokenized_source + tokenized_target))
                if not self.train_on_source:
                    labels.append(
                        torch.tensor([IGNORE_INDEX for _ in range(len(tokenized_source))] + copy.deepcopy(tokenized_target))
                    )
                else:
                    labels.append(torch.tensor(copy.deepcopy(tokenized_source + tokenized_target)))
            else:
                input_ids.append(torch.tensor(tokenized_source))
        # Apply padding
        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) if not self.predict_with_generate else None
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': input_ids.ne(self.tokenizer.pad_token_id),
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict

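# Illustrative use of the collator (a sketch, not executed by this script): it pads a
# batch of {'input': ..., 'output': ...} records and, with train_on_source=False, masks
# the prompt tokens with IGNORE_INDEX so the loss is only computed on the target part.
#
#   collator = DataCollatorForCausalLM(
#       tokenizer=tokenizer, source_max_len=1024, target_max_len=1024,
#       train_on_source=False, predict_with_generate=False,
#   )
#   batch = collator([{'input': 'Translate to Russian: hello', 'output': 'privet'}])
#   # batch['input_ids'], batch['attention_mask'], batch['labels']
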
def extract_unnatural_instructions_data(examples, extract_reformulations=False):
    out = {
        'input': [],
        'output': [],
    }
    for example_instances in examples['instances']:
        for instance in example_instances:
            out['input'].append(instance['instruction_with_input'])
            out['output'].append(instance['output'])
    if extract_reformulations:
        for example_reformulations in examples['reformulations']:
            if example_reformulations is not None:
                for instance in example_reformulations:
                    out['input'].append(instance['instruction_with_input'])
                    out['output'].append(instance['output'])
    return out

def make_data_module(tokenizer: transformers.PreTrainedTokenizer, args) -> Dict:
    # Load dataset.
    print(args.dataset)

    if args.dataset == 'txt':
        with open("txt.txt", "r", encoding="utf-8") as f:
            data = f.readlines()

        # Group consecutive lines into chunks of at most 512 characters.
        tmp = ''
        gdata = []
        current_length = 0
        print("Creating groups...")
        for s in data:
            if current_length + len(s) <= 512:
                tmp = tmp + s + "\n"
                current_length += len(s)
            else:
                gdata.append(tmp)
                tmp = s
                current_length = len(s)
        if tmp:
            # keep the final, partially filled group
            gdata.append(tmp)

        examples = [{'input': '', 'output': x} for x in gdata]
        from datasets import Dataset
        dataset = Dataset.from_list(examples)

    elif args.dataset == 'dataset':
        dataset = load_dataset("json", data_files='dataset.json')
    else:
        raise ValueError(f"Unknown dataset: {args.dataset}")

    if args.do_train:
        if args.dataset == 'txt':
            train_dataset = dataset
        else:
            train_dataset = dataset['train']
        if args.max_train_samples is not None and len(train_dataset) > args.max_train_samples:
            train_dataset = train_dataset.select(range(args.max_train_samples))
        if args.group_by_length:
            train_dataset = train_dataset.map(lambda x: {'length': len(x['input']) + len(x['output'])})

    # No evaluation split is built in this script.
    eval_dataset = None

    data_collator = DataCollatorForCausalLM(
        tokenizer=tokenizer,
        source_max_len=args.source_max_len,
        target_max_len=args.target_max_len,
        train_on_source=args.train_on_source,
        predict_with_generate=args.predict_with_generate,
    )
    return dict(
        train_dataset=train_dataset if args.do_train else None,
        eval_dataset=eval_dataset if args.do_eval else None,
        predict_dataset=eval_dataset if args.do_predict else None,
        data_collator=data_collator
    )

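# Expected shape of 'dataset.json' for the 'dataset' option (an assumption based on the
# fields the collator reads; adjust to your data): a JSON array of records such as
#   [{"input": "<prompt>", "output": "<response>"}, ...]
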
def get_last_checkpoint(checkpoint_dir):
    if isdir(checkpoint_dir):
        is_completed = exists(join(checkpoint_dir, 'completed'))
        if is_completed: return None, True  # already finished
        max_step = 0
        for filename in os.listdir(checkpoint_dir):
            if isdir(join(checkpoint_dir, filename)) and filename.startswith('checkpoint'):
                max_step = max(max_step, int(filename.replace('checkpoint-', '')))
                print("MX: ", max_step, " - ", filename)
        if max_step == 0: return None, is_completed  # training started, but no checkpoint
        checkpoint_dir = join(checkpoint_dir, f'checkpoint-{max_step}')
        print(f"Found a previous checkpoint at: {checkpoint_dir}")
        return checkpoint_dir, is_completed  # checkpoint found!
    return None, False  # first training

def train():
    hfparser = transformers.HfArgumentParser((
        ModelArguments, DataArguments, TrainingArguments
    ))
    model_args, data_args, training_args, extra_args = \
        hfparser.parse_args_into_dataclasses(return_remaining_strings=True)
    # training_args.generation_config = transformers.GenerationConfig(**vars(generation_args))
    args = argparse.Namespace(
        **vars(model_args), **vars(data_args), **vars(training_args)
    )

    checkpoint_dir, completed_training = get_last_checkpoint(args.output_dir)

    if completed_training:
        print('Detected that training was already completed!')

    model = get_accelerate_model(args, checkpoint_dir)
    training_args.skip_loading_checkpoint_weights = True

    load_existing_lora = os.path.exists('src_lora/adapter_model.safetensors')

    if load_existing_lora:
        print("Loading existing LoRA")
        adapters_weights = load_file('src_lora/adapter_model.safetensors')
        set_peft_model_state_dict(model, adapters_weights)

    model.config.use_cache = False
    print_trainable_parameters(args, model)
    print('loaded model')
    set_seed(args.seed)

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_path,
        cache_dir=args.cache_dir,
        padding_side="right",
        use_fast=True,
    )

    if tokenizer.pad_token is None:
        smart_tokenizer_and_embedding_resize(
            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
            tokenizer=tokenizer,
            model=model,
        )

    if isinstance(tokenizer, LlamaTokenizerFast):
        # LLaMA tokenizer may not have correct special tokens set.
        # Check and add them if missing to prevent them from being parsed into different tokens.
        # Note that these are present in the vocabulary.
        # Note also that `model.config.pad_token_id` is 0, which corresponds to the `<unk>` token.
        tokenizer.add_special_tokens(
            {
                "eos_token": tokenizer.convert_ids_to_tokens(model.config.eos_token_id),
                "bos_token": tokenizer.convert_ids_to_tokens(model.config.bos_token_id),
                "unk_token": tokenizer.convert_ids_to_tokens(model.config.pad_token_id),
            }
        )

    data_module = make_data_module(tokenizer=tokenizer, args=args)
    trainer = Seq2SeqTrainer(
        # trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        **{k: v for k, v in data_module.items() if k != 'predict_dataset'},
    )

    # Callbacks
    if not args.full_finetune:
        trainer.add_callback(SavePeftModelCallback)

    # Verifying the datatypes.
    dtypes = {}
    for _, p in model.named_parameters():
        dtype = p.dtype
        if dtype not in dtypes: dtypes[dtype] = 0
        dtypes[dtype] += p.numel()
    total = 0
    for k, v in dtypes.items(): total += v
    for k, v in dtypes.items():
        print(k, v, v / total)

    all_metrics = {"run_name": args.run_name}
    # Training
    if args.do_train:
        train_result = trainer.train(resume_from_checkpoint=False)
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()
        all_metrics.update(metrics)

    if args.do_train:
        with open(os.path.join(args.output_dir, "metrics.json"), "w") as fout:
            fout.write(json.dumps(all_metrics))

if __name__ == "__main__":
    train()
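
# One possible way to use the adapters saved by SavePeftModelCallback for inference
# (a sketch under assumptions, not part of this script: the checkpoint path is
# hypothetical and the auto_gptq/peft versions are assumed to match those used above):
#
#   model = AutoGPTQForCausalLM.from_quantized('./src/', device_map='auto', trainable=False)
#   model = PeftModel.from_pretrained(model, './output/checkpoint-250/adapter_model')
#   model.eval()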