TedYeh commited on
Commit
da060de
1 Parent(s): 2a548f2

add t5 package

Browse files
app.py CHANGED
@@ -1,15 +1,17 @@
1
  import gradio as gr
 
2
  from transformers import AutoTokenizer, T5ForConditionalGeneration
3
- tokenizer = AutoTokenizer.from_pretrained("CodeTed/CGEDit")
4
- model = T5ForConditionalGeneration.from_pretrained("CodeTed/CGEDit")
 
5
 
6
  def cged_correction(sentence, function):
7
  prompt = {"錯別字校正":"糾正句子中的錯字:", "文法校正":"糾正句子中的錯誤:",
8
  "文本重構":"在不改動文意的情況下改寫句子:", "文本簡化":"在不改動文意的情況下改寫句子:", "整體校正":"修改句子的錯誤或使其更通順:"}
9
- input_ids = tokenizer(prompt[function] + sentence, return_tensors="pt").input_ids
10
- outputs = model.generate(input_ids, max_length=200)
11
- edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
12
- return edited_text
13
 
14
  with gr.Blocks() as demo:
15
  gr.Markdown(
 
1
  import gradio as gr
2
+ from t5.t5_model import T5Model
3
  from transformers import AutoTokenizer, T5ForConditionalGeneration
4
+ #tokenizer = AutoTokenizer.from_pretrained("CodeTed/CGEDit")
5
+ #model = T5ForConditionalGeneration.from_pretrained("CodeTed/CGEDit")
6
+ model = T5Model('t5', "CodeTed/CGEDit", args={"eval_batch_size": 1}, cuda_device=-1, evaluate=True)
7
 
8
  def cged_correction(sentence, function):
9
  prompt = {"錯別字校正":"糾正句子中的錯字:", "文法校正":"糾正句子中的錯誤:",
10
  "文本重構":"在不改動文意的情況下改寫句子:", "文本簡化":"在不改動文意的情況下改寫句子:", "整體校正":"修改句子的錯誤或使其更通順:"}
11
+ #input_ids = tokenizer(prompt[function] + sentence, return_tensors="pt").input_ids
12
+ outputs = model.predict([prompt[function] + sentence + "_輸出句:"])
13
+ #edited_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
14
+ return outputs[0]
15
 
16
  with gr.Blocks() as demo:
17
  gr.Markdown(
t5/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description:
5
+ """
6
+ from textgen.config.model_args import T5Args, CopyT5Args
7
+ from textgen.t5.t5_model import T5Model
8
+ from textgen.t5.copyt5_model import CopyT5Model
9
+ from textgen.t5.copyt5_utils import ZHTokenizer
t5/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (448 Bytes). View file
 
t5/__pycache__/copyt5_model.cpython-38.pyc ADDED
Binary file (28.5 kB). View file
 
t5/__pycache__/copyt5_utils.cpython-38.pyc ADDED
Binary file (6.18 kB). View file
 
t5/__pycache__/t5_model.cpython-38.pyc ADDED
Binary file (28.4 kB). View file
 
t5/__pycache__/t5_utils.cpython-38.pyc ADDED
Binary file (5.93 kB). View file
 
t5/config/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: refer https://github.com/ThilinaRajapakse/simpletransformers
5
+ """
t5/config/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (265 Bytes). View file
 
t5/config/__pycache__/model_args.cpython-38.pyc ADDED
Binary file (15.7 kB). View file
 
t5/config/global_args.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: refer https://github.com/ThilinaRajapakse/simpletransformers
5
+ """
6
+ import sys
7
+ from multiprocessing import cpu_count
8
+
9
+ global_args = {
10
+ "adam_epsilon": 1e-8,
11
+ "best_model_dir": "outputs/best_model",
12
+ "cache_dir": "cache_dir/",
13
+ "config": {},
14
+ "do_lower_case": False,
15
+ "early_stopping_consider_epochs": False,
16
+ "early_stopping_delta": 0,
17
+ "early_stopping_metric": "eval_loss",
18
+ "early_stopping_metric_minimize": True,
19
+ "early_stopping_patience": 3,
20
+ "encoding": None,
21
+ "eval_batch_size": 8,
22
+ "evaluate_during_training": False,
23
+ "evaluate_during_training_silent": True,
24
+ "evaluate_during_training_steps": 2000,
25
+ "evaluate_during_training_verbose": False,
26
+ "fp16": True,
27
+ "gradient_accumulation_steps": 1,
28
+ "learning_rate": 4e-5,
29
+ "local_rank": -1,
30
+ "logging_steps": 50,
31
+ "manual_seed": None,
32
+ "max_grad_norm": 1.0,
33
+ "max_seq_length": 128,
34
+ "multiprocessing_chunksize": 500,
35
+ "n_gpu": 1,
36
+ "no_cache": False,
37
+ "no_save": False,
38
+ "num_train_epochs": 1,
39
+ "output_dir": "outputs/",
40
+ "overwrite_output_dir": False,
41
+ "process_count": cpu_count() - 2 if cpu_count() > 2 else 1,
42
+ "reprocess_input_data": True,
43
+ "save_best_model": True,
44
+ "save_eval_checkpoints": True,
45
+ "save_model_every_epoch": True,
46
+ "save_steps": 2000,
47
+ "save_optimizer_and_scheduler": True,
48
+ "silent": False,
49
+ "tensorboard_dir": None,
50
+ "train_batch_size": 8,
51
+ "use_cached_eval_features": False,
52
+ "use_early_stopping": False,
53
+ "use_multiprocessing": False,
54
+ "wandb_kwargs": {},
55
+ "wandb_project": None,
56
+ "warmup_ratio": 0.06,
57
+ "warmup_steps": 0,
58
+ "weight_decay": 0,
59
+ }
60
+
61
+ if sys.platform == "win32":
62
+ global_args["process_count"] = min(global_args["process_count"], 61)
t5/config/model_args.py ADDED
@@ -0,0 +1,464 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: refer https://github.com/ThilinaRajapakse/simpletransformers
5
+ """
6
+ import json
7
+ import os
8
+ import sys
9
+ from dataclasses import asdict, dataclass, field
10
+ from multiprocessing import cpu_count
11
+ from typing import Optional
12
+
13
+ from loguru import logger
14
+ from torch.utils.data import Dataset
15
+
16
+
17
+ def get_default_process_count():
18
+ process_count = cpu_count() - 2 if cpu_count() > 2 else 1
19
+ if sys.platform == "win32":
20
+ process_count = min(process_count, 61)
21
+
22
+ return process_count
23
+
24
+
25
+ def get_special_tokens():
26
+ return ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]
27
+
28
+
29
+ @dataclass
30
+ class ModelArgs:
31
+ adafactor_beta1: float = None
32
+ adafactor_clip_threshold: float = 1.0
33
+ adafactor_decay_rate: float = -0.8
34
+ adafactor_eps: tuple = field(default_factory=lambda: (1e-30, 1e-3))
35
+ adafactor_relative_step: bool = True
36
+ adafactor_scale_parameter: bool = True
37
+ adafactor_warmup_init: bool = True
38
+ adam_epsilon: float = 1e-8
39
+ best_model_dir: str = "outputs/best_model"
40
+ cache_dir: str = "cache_dir/"
41
+ config: dict = field(default_factory=dict)
42
+ cosine_schedule_num_cycles: float = 0.5
43
+ custom_layer_parameters: list = field(default_factory=list)
44
+ custom_parameter_groups: list = field(default_factory=list)
45
+ dataloader_num_workers: int = 0
46
+ do_lower_case: bool = False
47
+ dynamic_quantize: bool = False
48
+ early_stopping_consider_epochs: bool = False
49
+ early_stopping_delta: float = 0
50
+ early_stopping_metric: str = "eval_loss"
51
+ early_stopping_metric_minimize: bool = True
52
+ early_stopping_patience: int = 3
53
+ encoding: str = "utf-8"
54
+ eval_batch_size: int = 8
55
+ evaluate_during_training: bool = False
56
+ evaluate_during_training_silent: bool = True
57
+ evaluate_during_training_steps: int = 6000
58
+ evaluate_during_training_verbose: bool = False
59
+ evaluate_each_epoch: bool = True
60
+ fp16: bool = False
61
+ gradient_accumulation_steps: int = 1
62
+ learning_rate: float = 2e-5
63
+ local_rank: int = -1
64
+ logging_steps: int = 50
65
+ manual_seed: int = None
66
+ max_grad_norm: float = 1.0
67
+ max_seq_length: int = 128 # max length of input sequence
68
+ model_name: str = None
69
+ model_type: str = None
70
+ multiprocessing_chunksize: int = -1
71
+ n_gpu: int = 2
72
+ no_cache: bool = False
73
+ no_save: bool = False
74
+ not_saved_args: list = field(default_factory=list)
75
+ num_train_epochs: int = 1
76
+ optimizer: str = "AdamW"
77
+ output_dir: str = "outputs/"
78
+ overwrite_output_dir: bool = True
79
+ polynomial_decay_schedule_lr_end: float = 1e-7
80
+ polynomial_decay_schedule_power: float = 1.0
81
+ process_count: int = field(default_factory=get_default_process_count)
82
+ quantized_model: bool = False
83
+ reprocess_input_data: bool = False
84
+ save_best_model: bool = True
85
+ save_eval_checkpoints: bool = True
86
+ save_model_every_epoch: bool = False
87
+ save_optimizer_and_scheduler: bool = True
88
+ save_steps: int = 10000
89
+ scheduler: str = "linear_schedule_with_warmup"
90
+ silent: bool = False
91
+ skip_special_tokens: bool = True
92
+ tensorboard_dir: str = None
93
+ thread_count: int = None
94
+ tokenizer_name: str = None
95
+ tokenizer_type: str = None
96
+ train_batch_size: int = 8
97
+ train_custom_parameters_only: bool = False
98
+ use_cached_eval_features: bool = False
99
+ use_early_stopping: bool = False
100
+ use_hf_datasets: bool = False
101
+ use_multiprocessing: bool = True
102
+ use_multiprocessing_for_evaluation: bool = True
103
+ wandb_kwargs: dict = field(default_factory=dict)
104
+ wandb_project: str = None
105
+ warmup_ratio: float = 0.06
106
+ warmup_steps: int = 0
107
+ weight_decay: float = 0.0
108
+
109
+ def update_from_dict(self, new_values):
110
+ if isinstance(new_values, dict):
111
+ for key, value in new_values.items():
112
+ setattr(self, key, value)
113
+ else:
114
+ raise (TypeError(f"{new_values} is not a Python dict."))
115
+
116
+ def get_args_for_saving(self):
117
+ args_for_saving = {key: value for key, value in asdict(self).items() if key not in self.not_saved_args}
118
+ return args_for_saving
119
+
120
+ def save(self, output_dir):
121
+ os.makedirs(output_dir, exist_ok=True)
122
+ with open(os.path.join(output_dir, "model_args.json"), "w", encoding='utf-8') as f:
123
+ args_dict = self.get_args_for_saving()
124
+ if args_dict['dataset_class'] is not None and not isinstance(args_dict["dataset_class"], str):
125
+ args_dict['dataset_class'] = type(args_dict['dataset_class']).__name__
126
+ if args_dict["tokenizer_type"] is not None and not isinstance(args_dict["tokenizer_type"], str):
127
+ args_dict["tokenizer_type"] = type(args_dict["tokenizer_type"]).__name__
128
+ json.dump(args_dict, f)
129
+
130
+ def load(self, input_dir):
131
+ if input_dir:
132
+ model_args_file = os.path.join(input_dir, "model_args.json")
133
+ if os.path.isfile(model_args_file):
134
+ with open(model_args_file, "r", encoding='utf-8') as f:
135
+ model_args = json.load(f)
136
+ if model_args["dataset_class"]:
137
+ logger.warning(
138
+ "This model was trained using a custom dataset_class."
139
+ "This cannot be loaded automatically and must be specified in the model args"
140
+ "when loading the model."
141
+ )
142
+ self.update_from_dict(model_args)
143
+
144
+
145
+ @dataclass
146
+ class T5Args(ModelArgs):
147
+ """
148
+ Model args for a T5Model
149
+ """
150
+
151
+ model_class: str = "T5Model"
152
+ dataset_class: Dataset = None
153
+ do_sample: bool = False
154
+ early_stopping: bool = True
155
+ evaluate_generated_text: bool = False
156
+ length_penalty: float = 2.0
157
+ max_length: int = 180 # max length of the sequence to be generated
158
+ max_steps: int = -1
159
+ num_beams: int = 1
160
+ num_return_sequences: int = 1
161
+ preprocess_inputs: bool = True
162
+ repetition_penalty: float = 1.0
163
+ scheduler: str = "constant_schedule_with_warmup"
164
+ adafactor_relative_step: bool = False
165
+ adafactor_scale_parameter: bool = False
166
+ adafactor_warmup_init: bool = False
167
+ learning_rate: float = 5e-4
168
+ optimizer: str = "AdamW"
169
+ special_tokens_list: list = field(default_factory=list)
170
+ top_k: float = None
171
+ top_p: float = None
172
+ use_multiprocessed_decoding: bool = False
173
+
174
+
175
+ @dataclass
176
+ class CopyT5Args(ModelArgs):
177
+ """
178
+ Model args for a CopyT5Model
179
+ """
180
+
181
+ model_class: str = "CopyT5Model"
182
+ dataset_class: Dataset = None
183
+ do_sample: bool = False
184
+ early_stopping: bool = True
185
+ evaluate_generated_text: bool = False
186
+ length_penalty: float = 2.0
187
+ max_length: int = 128 # max length of the sequence to be generated
188
+ max_steps: int = -1
189
+ num_beams: int = 3
190
+ num_return_sequences: int = 1
191
+ preprocess_inputs: bool = True
192
+ repetition_penalty: float = 1.0
193
+ scheduler: str = "linear_schedule_with_warmup"
194
+ adafactor_relative_step: bool = False
195
+ adafactor_scale_parameter: bool = False
196
+ adafactor_warmup_init: bool = False
197
+ learning_rate: float = 1e-3
198
+ optimizer: str = "AdamW"
199
+ special_tokens_list: list = field(default_factory=list)
200
+ top_k: float = None
201
+ top_p: float = None
202
+ use_multiprocessed_decoding: bool = False
203
+
204
+
205
+ @dataclass
206
+ class LanguageModelingArgs(ModelArgs):
207
+ """
208
+ Model args for a LanguageModelingModel
209
+ """
210
+
211
+ model_class: str = "LanguageModelingModel"
212
+ block_size: int = -1
213
+ config_name: str = None
214
+ dataset_class: Dataset = None
215
+ dataset_type: str = "None"
216
+ discriminator_config: dict = field(default_factory=dict)
217
+ discriminator_loss_weight: float = 50.0
218
+ generator_config: dict = field(default_factory=dict)
219
+ max_steps: int = -1
220
+ min_frequency: int = 2
221
+ mlm: bool = True
222
+ mlm_probability: float = 0.15
223
+ sliding_window: bool = False
224
+ special_tokens: list = field(default_factory=get_special_tokens)
225
+ stride: float = 0.8
226
+ tie_generator_and_discriminator_embeddings: bool = True
227
+ tokenizer_name: str = None
228
+ vocab_size: int = None
229
+ clean_text: bool = True
230
+ handle_chinese_chars: bool = True
231
+ special_tokens_list: list = field(default_factory=list)
232
+ strip_accents: bool = True
233
+ local_rank: int = -1
234
+
235
+
236
+ @dataclass
237
+ class Seq2SeqArgs(ModelArgs):
238
+ """
239
+ Model args for a Seq2SeqModel
240
+ """
241
+
242
+ model_class: str = "Seq2SeqModel"
243
+ base_marian_model_name: str = None
244
+ dataset_class: Dataset = None
245
+ do_sample: bool = False
246
+ early_stopping: bool = True
247
+ evaluate_generated_text: bool = False
248
+ faiss_d: int = 768
249
+ faiss_m: int = 128
250
+ length_penalty: float = 2.0
251
+ max_length: int = 128 # max length of the sequence to be generated
252
+ max_steps: int = -1
253
+ num_beams: int = 1
254
+ num_return_sequences: int = 1
255
+ rag_embed_batch_size: int = 16
256
+ repetition_penalty: float = 1.0
257
+ top_k: float = None
258
+ top_p: float = None
259
+ use_multiprocessed_decoding: bool = False
260
+ save_knowledge_dataset: bool = True
261
+ save_knowledge_dataset_with_checkpoints: bool = False
262
+ split_text_character: str = " "
263
+ split_text_n: int = 100
264
+ src_lang: str = "en_XX"
265
+ tgt_lang: str = "ro_RO"
266
+
267
+
268
+ @dataclass
269
+ class LanguageGenerationArgs(ModelArgs):
270
+ """
271
+ Model args for a LanguageGenerationModel
272
+ """
273
+
274
+ model_class: str = "LanguageGenerationModel"
275
+ do_sample: bool = True
276
+ early_stopping: bool = True
277
+ evaluate_generated_text: bool = False
278
+ length_penalty: float = 2.0
279
+ max_length: int = 128 # max length of the sequence to be generated
280
+ max_steps: int = -1
281
+ num_beams: int = 1
282
+ num_return_sequences: int = 1
283
+ repetition_penalty: float = 1.0
284
+ top_k: float = 50
285
+ top_p: float = 0.95
286
+ prompt: str = ""
287
+ stop_token: str = None
288
+ temperature: float = 1.0
289
+ padding_text: str = ""
290
+ xlm_language: str = ""
291
+ config_name: str = None
292
+ tokenizer_name: str = None
293
+ special_tokens_list: list = field(default_factory=list)
294
+
295
+
296
+ @dataclass
297
+ class SongNetArgs(LanguageModelingArgs):
298
+ """
299
+ Model args for a SongNetModel
300
+ """
301
+
302
+ model_class: str = "SongNetModel"
303
+ dataset_class: Dataset = None
304
+ do_sample: bool = False
305
+ early_stopping: bool = True
306
+ evaluate_generated_text: bool = False
307
+ length_penalty: float = 2.0
308
+ max_length: int = 128
309
+ min_length: int = 10
310
+ max_steps: int = -1
311
+ num_beams: int = 3
312
+ num_return_sequences: int = 1
313
+ repetition_penalty: float = 1.0
314
+ scheduler: str = None
315
+ adafactor_relative_step: bool = False
316
+ adafactor_scale_parameter: bool = False
317
+ adafactor_warmup_init: bool = False
318
+ learning_rate: float = 1e-3
319
+ early_stopping_metric: str = "eval_ppl"
320
+ special_tokens_list: list = field(default_factory=list)
321
+ save_eval_checkpoints: bool = False
322
+ skip_special_tokens: bool = False
323
+ k: int = 16
324
+ use_multiprocessed_decoding: bool = False
325
+ embed_dim: int = 768
326
+ ff_embed_dim: int = 3072
327
+ num_heads: int = 12
328
+ num_layers: int = 12
329
+ dropout: float = 0.2
330
+ warmup_ratio: float = 0.05
331
+ weight_decay: float = 0.0
332
+ smoothing_factor: float = 0.1
333
+
334
+
335
+ @dataclass
336
+ class ChatGlmArgs(ModelArgs):
337
+ """
338
+ Model args for a ChatGLMModel
339
+ """
340
+
341
+ model_class: str = "ChatGlmArgs"
342
+ dataset_class: Dataset = None
343
+ learning_rate: float = 2e-5
344
+ fp16: bool = True
345
+ bf16: bool = False
346
+ int8: bool = False
347
+ int4: bool = False
348
+ debug: bool = False
349
+ max_seq_length: int = 256 # max length of input sequence
350
+ max_length = 384 # max length of the sequence to be generated
351
+ do_sample: bool = True
352
+ early_stopping: bool = True
353
+ is_train_on_prompt: bool = False # if compute loss with prompt labels
354
+ evaluate_generated_text: bool = True
355
+ report_to = "tensorboard"
356
+ optimizer: str = "adamw_torch"
357
+ save_strategy: str = "steps"
358
+ evaluation_strategy: str = "no"
359
+ eval_steps: int = 50
360
+ save_steps: int = 400
361
+ max_eval_samples: int = 20
362
+ length_penalty: float = 2.0
363
+ num_beams: int = 4
364
+ num_return_sequences: int = 1
365
+ repetition_penalty: float = 1.0
366
+ temperature: float = 0.1
367
+ special_tokens_list: list = field(default_factory=list)
368
+ top_k: float = 40
369
+ top_p: float = 0.75
370
+ model_name_or_path: Optional[str] = field(default="THUDM/chatglm-6b")
371
+ use_peft: bool = True
372
+ peft_type: str = "LORA"
373
+ peft_bin_name: str = "adapter_model.bin"
374
+ lora_r: int = 8
375
+ lora_alpha = 32
376
+ lora_dropout = 0.05
377
+ lora_target_modules = ["all"] # ["all"] or ["query_key_value"]
378
+ lora_bias = "none"
379
+ adalora_init_r: int = 12
380
+ adalora_tinit: int = 200
381
+ adalora_tfinal: int = 1000
382
+ adalora_delta_t: int = 10
383
+ lora_beta: float = 0.85
384
+ num_virtual_tokens: int = 20
385
+ prompt_encoder_hidden_size: int = 128
386
+ num_train_epochs = 1
387
+ max_steps = -1
388
+ per_device_train_batch_size = 2
389
+ eval_batch_size: int = 4
390
+ gradient_accumulation_steps = 1
391
+ gradient_checkpointing: bool = True
392
+ torch_compile: bool = False
393
+ save_total_limit = 10
394
+ remove_unused_columns = False
395
+ logging_steps = 50
396
+ resume_from_checkpoint: str = None
397
+ qlora: bool = False
398
+
399
+
400
+ @dataclass
401
+ class GptArgs(ModelArgs):
402
+ """
403
+ Model args for a GptModel
404
+ """
405
+
406
+ model_class: str = "GptArgs"
407
+ dataset_class: Dataset = None
408
+ learning_rate: float = 2e-5
409
+ fp16: bool = True
410
+ bf16: bool = False
411
+ int8: bool = False
412
+ int4: bool = False
413
+ debug: bool = False
414
+ max_seq_length: int = 256 # max length of input sequence
415
+ max_length = 256 # max length of the sequence to be generated
416
+ do_sample: bool = True
417
+ early_stopping: bool = True
418
+ evaluate_generated_text: bool = True
419
+ is_train_on_prompt: bool = False # if compute loss with prompt labels
420
+ warmup_steps: int = 50
421
+ report_to = "tensorboard"
422
+ optimizer: str = "adamw_torch"
423
+ save_strategy: str = "steps"
424
+ eval_steps: int = 200
425
+ save_steps: int = 400
426
+ pad_to_multiple_of: int = 8
427
+ max_eval_samples: int = 20
428
+ length_penalty: float = 2.0
429
+ num_beams: int = 1
430
+ num_return_sequences: int = 1
431
+ repetition_penalty: float = 1.3
432
+ temperature: float = 0.4
433
+ special_tokens_list: list = field(default_factory=list)
434
+ top_k: float = 40
435
+ top_p: float = 0.9
436
+ model_name_or_path: Optional[str] = field(default="shibing624/chinese-alpaca-plus-7b-hf")
437
+ use_peft: bool = True
438
+ peft_type: str = "LORA"
439
+ peft_bin_name: str = "adapter_model.bin"
440
+ lora_r: int = 8
441
+ lora_alpha = 16
442
+ lora_dropout = 0.05
443
+ lora_target_modules = ["all"] # ["all"] or ["k_proj"]
444
+ lora_bias = "none"
445
+ adalora_init_r: int = 12
446
+ adalora_tinit: int = 200
447
+ adalora_tfinal: int = 1000
448
+ adalora_delta_t: int = 10
449
+ lora_beta: float = 0.85
450
+ num_virtual_tokens: int = 20
451
+ prompt_encoder_hidden_size: int = 128
452
+ num_train_epochs = 3
453
+ max_steps = -1
454
+ per_device_train_batch_size = 2
455
+ eval_batch_size: int = 4
456
+ gradient_accumulation_steps = 1
457
+ save_total_limit = 10
458
+ remove_unused_columns = False
459
+ logging_steps = 50
460
+ resume_from_checkpoint: str = None
461
+ gradient_checkpointing: bool = True
462
+ torch_compile: bool = False
463
+ trust_remote_code: bool = True
464
+ qlora: bool = False
t5/t5_model.py ADDED
@@ -0,0 +1,1256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: refer https://github.com/ThilinaRajapakse/simpletransformers
5
+ """
6
+
7
+ import math
8
+ import os
9
+ import random
10
+ import warnings
11
+ from dataclasses import asdict
12
+ from multiprocessing import Pool
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import torch
17
+ from loguru import logger
18
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
19
+ from torch.utils.tensorboard import SummaryWriter
20
+ from tqdm.auto import tqdm, trange
21
+ from transformers import ByT5Tokenizer
22
+ from transformers import MT5Config, MT5ForConditionalGeneration
23
+ from transformers import T5Config, T5ForConditionalGeneration, T5Tokenizer, TextStreamer
24
+ from transformers.optimization import AdamW, Adafactor
25
+ from transformers.optimization import (
26
+ get_constant_schedule,
27
+ get_constant_schedule_with_warmup,
28
+ get_linear_schedule_with_warmup,
29
+ get_cosine_schedule_with_warmup,
30
+ get_cosine_with_hard_restarts_schedule_with_warmup,
31
+ get_polynomial_decay_schedule_with_warmup,
32
+ )
33
+
34
+ from t5.config.model_args import T5Args
35
+ from t5.t5_utils import T5Dataset, load_hf_dataset
36
+
37
+ try:
38
+ import wandb
39
+
40
+ wandb_available = True
41
+ except ImportError:
42
+ wandb_available = False
43
+
44
+ has_cuda = torch.cuda.is_available()
45
+ os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
46
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
47
+
48
+
49
+ def chunks(lst, n):
50
+ """Yield successive n-sized chunks from lst."""
51
+ for i in range(0, len(lst), n):
52
+ yield lst[i: i + n]
53
+
54
+
55
+ MODEL_CLASSES = {
56
+ "t5": (T5Config, T5ForConditionalGeneration),
57
+ "mt5": (MT5Config, MT5ForConditionalGeneration),
58
+ "byt5": (T5Config, T5ForConditionalGeneration),
59
+ }
60
+
61
+
62
+ class T5Model:
63
+ def __init__(
64
+ self,
65
+ model_type,
66
+ model_name,
67
+ args=None,
68
+ tokenizer=None,
69
+ use_cuda=has_cuda,
70
+ cuda_device=-1,
71
+ evaluate=False,
72
+ **kwargs,
73
+ ):
74
+
75
+ """
76
+ Initializes a T5Model model.
77
+
78
+ Args:
79
+ model_type: The type of model (t5, mt5, byt5)
80
+ model_name: The exact architecture and trained weights to use. This may be a Hugging Face Transformers compatible pre-trained model, a community model, or the path to a directory containing model files.
81
+ args (optional): Default args will be used if this parameter is not provided. If provided, it should be a dict containing the args that should be changed in the default args.
82
+ use_cuda (optional): Use GPU if available. Setting to False will force model to use CPU only.
83
+ cuda_device (optional): Specific GPU that should be used. Will use the first available GPU by default.
84
+ **kwargs (optional): For providing proxies, force_download, resume_download, cache_dir and other options specific to the 'from_pretrained' implementation where this will be supplied.
85
+ """ # noqa: ignore flake8"
86
+
87
+ self.args = self._load_model_args(model_name)
88
+
89
+ if isinstance(args, dict):
90
+ self.args.update_from_dict(args)
91
+ elif isinstance(args, T5Args):
92
+ self.args = args
93
+
94
+ self.is_sweeping = False
95
+
96
+ if self.args.manual_seed:
97
+ random.seed(self.args.manual_seed)
98
+ np.random.seed(self.args.manual_seed)
99
+ torch.manual_seed(self.args.manual_seed)
100
+ if self.args.n_gpu > 0:
101
+ torch.cuda.manual_seed_all(self.args.manual_seed)
102
+
103
+ if use_cuda:
104
+ if torch.cuda.is_available():
105
+ if cuda_device == -1:
106
+ self.device = torch.device("cuda")
107
+ else:
108
+ self.device = torch.device(f"cuda:{cuda_device}")
109
+ else:
110
+ raise ValueError(
111
+ "'use_cuda' set to True when cuda is unavailable."
112
+ "Make sure CUDA is available or set `use_cuda=False`."
113
+ )
114
+ else:
115
+ if torch.backends.mps.is_available():
116
+ self.device = torch.device("mps")
117
+ else:
118
+ self.device = "cpu"
119
+ logger.debug(f"Device: {self.device}")
120
+
121
+ self.results = {}
122
+
123
+ config_class, model_class = MODEL_CLASSES[model_type]
124
+
125
+ if model_name is None:
126
+ self.config = self.args.config
127
+ self.model = model_class(config=self.config)
128
+ else:
129
+ self.config = config_class.from_pretrained(model_name, **self.args.config)
130
+ self.model = model_class.from_pretrained(model_name, config=self.config)
131
+
132
+ if isinstance(tokenizer, T5Tokenizer):
133
+ self.tokenizer = tokenizer
134
+ self.model.resize_token_embeddings(len(self.tokenizer))
135
+ elif model_type == "byt5":
136
+ self.tokenizer = ByT5Tokenizer.from_pretrained(model_name, truncate=True)
137
+ else:
138
+ self.tokenizer = T5Tokenizer.from_pretrained(model_name, truncate=True)
139
+ print(len(self.tokenizer))
140
+ if not evaluate:
141
+ with open('./data/字音混淆集_s13.txt', 'r', encoding='utf-8') as confusion:
142
+ n = 0
143
+ for line in confusion.readlines()+[str(chr(c+65248)) for c in range(33, 127)]:
144
+ token = line.split(' ')[0]
145
+ n+=1
146
+ self.tokenizer.add_tokens([token])
147
+ with open('./data/字音混淆集.txt', 'r', encoding='utf-8') as confusion:
148
+ for line in confusion.readlines():
149
+ token = line.split(' ')[0]
150
+ n+=1
151
+ self.tokenizer.add_tokens([token])
152
+ with open('./data/wordtest4.txt', 'r', encoding='utf-8') as confusion:
153
+ for line in confusion.readlines():
154
+ token = line.split(',')[0]
155
+ n+=1
156
+ self.tokenizer.add_tokens([token])
157
+
158
+ with open('./data/vocab.txt', 'r', encoding='utf-8') as confusion:
159
+ for line in confusion.readlines():
160
+ n+=1
161
+ self.tokenizer.add_tokens([line.replace('\n', '')])
162
+
163
+ print(n)
164
+ self.streamer = TextStreamer(self.tokenizer)
165
+ print(len(self.tokenizer))
166
+ self.model.resize_token_embeddings(len(self.tokenizer))
167
+
168
+ if self.args.dynamic_quantize:
169
+ self.model = torch.quantization.quantize_dynamic(
170
+ self.model, {torch.nn.Linear}, dtype=torch.qint8
171
+ )
172
+
173
+ if not use_cuda:
174
+ self.args.fp16 = False
175
+
176
+ if self.args.special_tokens_list:
177
+ self.tokenizer.add_tokens(
178
+ self.args.special_tokens_list, special_tokens=True
179
+ )
180
+ self.model.resize_token_embeddings(len(self.tokenizer))
181
+
182
+ self.args.model_type = model_type
183
+ if model_name is None:
184
+ self.args.model_name = "T5_from_scratch"
185
+ else:
186
+ self.args.model_name = model_name
187
+
188
+ if self.args.wandb_project and not wandb_available:
189
+ warnings.warn(
190
+ "wandb_project specified but wandb is not available. Wandb disabled."
191
+ )
192
+ self.args.wandb_project = None
193
+
194
+ def train_model(
195
+ self,
196
+ train_data,
197
+ output_dir=None,
198
+ show_running_loss=True,
199
+ args=None,
200
+ eval_data=None,
201
+ verbose=True,
202
+ **kwargs,
203
+ ):
204
+ """
205
+ Trains the model using 'train_data'
206
+
207
+ Args:
208
+ train_data: Pandas DataFrame containing the 3 columns - `prefix`, `input_text`, `target_text`.
209
+ - `prefix`: A string indicating the task to perform. (E.g. `"question"`, `"stsb"`)
210
+ - `input_text`: The input text sequence. `prefix` is automatically prepended to form the full input. (<prefix>: <input_text>)
211
+ - `target_text`: The target sequence
212
+ output_dir: The directory where model files will be saved. If not given, self.args.output_dir will be used.
213
+ show_running_loss (optional): Set to False to prevent running loss from being printed to console. Defaults to True.
214
+ args (optional): Optional changes to the args dict of the model. Any changes made will persist for the model.
215
+ eval_data (optional): A DataFrame against which evaluation will be performed when evaluate_during_training is enabled. Is required if evaluate_during_training is enabled.
216
+ **kwargs: Additional metrics that should be used. Pass in the metrics as keyword arguments (name of metric: function to use).
217
+ A metric function should take in two parameters. The first parameter will be the true labels, and the second parameter will be the predictions. Both inputs
218
+ will be lists of strings. Note that this will slow down training significantly as the predicted sequences need to be generated.
219
+
220
+ Returns:
221
+ global_step: Number of global steps trained
222
+ training_details: Average training loss if evaluate_during_training is False or full training progress scores if evaluate_during_training is True
223
+ """ # noqa: ignore flake8"
224
+
225
+ if args:
226
+ self.args.update_from_dict(args)
227
+ if self.args.evaluate_during_training and eval_data is None:
228
+ raise ValueError(
229
+ "evaluate_during_training is enabled but eval_data is not specified."
230
+ " Pass eval_data to model.train_model() if using evaluate_during_training."
231
+ )
232
+
233
+ if not output_dir:
234
+ output_dir = self.args.output_dir
235
+
236
+ if (
237
+ os.path.exists(output_dir)
238
+ and os.listdir(output_dir)
239
+ and not self.args.overwrite_output_dir
240
+ ):
241
+ raise ValueError(
242
+ "Output directory ({}) already exists and is not empty."
243
+ " Set args.overwrite_output_dir = True to overcome.".format(output_dir)
244
+ )
245
+
246
+ self._move_model_to_device()
247
+
248
+ train_dataset = self.load_and_cache_examples(train_data, verbose=verbose)
249
+
250
+ os.makedirs(output_dir, exist_ok=True)
251
+
252
+ global_step, training_details = self.train(
253
+ train_dataset,
254
+ output_dir,
255
+ show_running_loss=show_running_loss,
256
+ eval_data=eval_data,
257
+ verbose=verbose,
258
+ **kwargs,
259
+ )
260
+
261
+ self.save_model(model=self.model)
262
+
263
+ if verbose:
264
+ logger.info(
265
+ " Training of {} model complete. Saved to {}.".format(
266
+ self.args.model_name, output_dir
267
+ )
268
+ )
269
+
270
+ return global_step, training_details
271
+
272
+ def train(
273
+ self,
274
+ train_dataset,
275
+ output_dir,
276
+ show_running_loss=True,
277
+ eval_data=None,
278
+ verbose=True,
279
+ **kwargs,
280
+ ):
281
+ """
282
+ Trains the model on train_dataset.
283
+
284
+ Utility function to be used by the train_model() method. Not intended to be used directly.
285
+ """
286
+
287
+ model = self.model
288
+ args = self.args
289
+ device = self.device
290
+
291
+ tb_writer = SummaryWriter(log_dir=args.tensorboard_dir)
292
+ train_sampler = RandomSampler(train_dataset)
293
+ train_dataloader = DataLoader(
294
+ train_dataset,
295
+ sampler=train_sampler,
296
+ batch_size=args.train_batch_size,
297
+ num_workers=self.args.dataloader_num_workers,
298
+ )
299
+
300
+ if args.max_steps > 0:
301
+ t_total = args.max_steps
302
+ args.num_train_epochs = (
303
+ args.max_steps
304
+ // (len(train_dataloader) // args.gradient_accumulation_steps)
305
+ + 1
306
+ )
307
+ else:
308
+ t_total = (
309
+ len(train_dataloader)
310
+ // args.gradient_accumulation_steps
311
+ * args.num_train_epochs
312
+ )
313
+
314
+ no_decay = ["bias", "LayerNorm.weight"]
315
+
316
+ optimizer_grouped_parameters = []
317
+ custom_parameter_names = set()
318
+ for group in self.args.custom_parameter_groups:
319
+ params = group.pop("params")
320
+ custom_parameter_names.update(params)
321
+ param_group = {**group}
322
+ param_group["params"] = [
323
+ p for n, p in model.named_parameters() if n in params
324
+ ]
325
+ optimizer_grouped_parameters.append(param_group)
326
+
327
+ for group in self.args.custom_layer_parameters:
328
+ layer_number = group.pop("layer")
329
+ layer = f"layer.{layer_number}."
330
+ group_d = {**group}
331
+ group_nd = {**group}
332
+ group_nd["weight_decay"] = 0.0
333
+ params_d = []
334
+ params_nd = []
335
+ for n, p in model.named_parameters():
336
+ if n not in custom_parameter_names and layer in n:
337
+ if any(nd in n for nd in no_decay):
338
+ params_nd.append(p)
339
+ else:
340
+ params_d.append(p)
341
+ custom_parameter_names.add(n)
342
+ group_d["params"] = params_d
343
+ group_nd["params"] = params_nd
344
+
345
+ optimizer_grouped_parameters.append(group_d)
346
+ optimizer_grouped_parameters.append(group_nd)
347
+
348
+ if not self.args.train_custom_parameters_only:
349
+ optimizer_grouped_parameters.extend(
350
+ [
351
+ {
352
+ "params": [
353
+ p
354
+ for n, p in model.named_parameters()
355
+ if n not in custom_parameter_names
356
+ and not any(nd in n for nd in no_decay)
357
+ ],
358
+ "weight_decay": args.weight_decay,
359
+ },
360
+ {
361
+ "params": [
362
+ p
363
+ for n, p in model.named_parameters()
364
+ if n not in custom_parameter_names
365
+ and any(nd in n for nd in no_decay)
366
+ ],
367
+ "weight_decay": 0.0,
368
+ },
369
+ ]
370
+ )
371
+
372
+ warmup_steps = math.ceil(t_total * args.warmup_ratio)
373
+ args.warmup_steps = (
374
+ warmup_steps if args.warmup_steps == 0 else args.warmup_steps
375
+ )
376
+
377
+ if args.optimizer == "AdamW":
378
+ optimizer = AdamW(
379
+ optimizer_grouped_parameters,
380
+ lr=args.learning_rate,
381
+ eps=args.adam_epsilon,
382
+ )
383
+ elif args.optimizer == "Adafactor":
384
+ optimizer = Adafactor(
385
+ optimizer_grouped_parameters,
386
+ lr=args.learning_rate,
387
+ eps=args.adafactor_eps,
388
+ clip_threshold=args.adafactor_clip_threshold,
389
+ decay_rate=args.adafactor_decay_rate,
390
+ beta1=args.adafactor_beta1,
391
+ weight_decay=args.weight_decay,
392
+ scale_parameter=args.adafactor_scale_parameter,
393
+ relative_step=args.adafactor_relative_step,
394
+ warmup_init=args.adafactor_warmup_init,
395
+ )
396
+
397
+ else:
398
+ raise ValueError(
399
+ "{} is not a valid optimizer class. Please use one of ('AdamW', 'Adafactor') instead.".format(
400
+ args.optimizer
401
+ )
402
+ )
403
+
404
+ if args.scheduler == "constant_schedule":
405
+ scheduler = get_constant_schedule(optimizer)
406
+
407
+ elif args.scheduler == "constant_schedule_with_warmup":
408
+ scheduler = get_constant_schedule_with_warmup(
409
+ optimizer, num_warmup_steps=args.warmup_steps
410
+ )
411
+
412
+ elif args.scheduler == "linear_schedule_with_warmup":
413
+ scheduler = get_linear_schedule_with_warmup(
414
+ optimizer,
415
+ num_warmup_steps=args.warmup_steps,
416
+ num_training_steps=t_total,
417
+ )
418
+
419
+ elif args.scheduler == "cosine_schedule_with_warmup":
420
+ scheduler = get_cosine_schedule_with_warmup(
421
+ optimizer,
422
+ num_warmup_steps=args.warmup_steps,
423
+ num_training_steps=t_total,
424
+ num_cycles=args.cosine_schedule_num_cycles,
425
+ )
426
+
427
+ elif args.scheduler == "cosine_with_hard_restarts_schedule_with_warmup":
428
+ scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
429
+ optimizer,
430
+ num_warmup_steps=args.warmup_steps,
431
+ num_training_steps=t_total,
432
+ num_cycles=args.cosine_schedule_num_cycles,
433
+ )
434
+
435
+ elif args.scheduler == "polynomial_decay_schedule_with_warmup":
436
+ scheduler = get_polynomial_decay_schedule_with_warmup(
437
+ optimizer,
438
+ num_warmup_steps=args.warmup_steps,
439
+ num_training_steps=t_total,
440
+ lr_end=args.polynomial_decay_schedule_lr_end,
441
+ power=args.polynomial_decay_schedule_power,
442
+ )
443
+
444
+ else:
445
+ raise ValueError("{} is not a valid scheduler.".format(args.scheduler))
446
+
447
+ if (
448
+ args.model_name
449
+ and os.path.isfile(os.path.join(args.model_name, "optimizer.pt"))
450
+ and os.path.isfile(os.path.join(args.model_name, "scheduler.pt"))
451
+ ):
452
+ # Load in optimizer and scheduler states
453
+ optimizer.load_state_dict(
454
+ torch.load(os.path.join(args.model_name, "optimizer.pt"))
455
+ )
456
+ scheduler.load_state_dict(
457
+ torch.load(os.path.join(args.model_name, "scheduler.pt"))
458
+ )
459
+
460
+ if args.n_gpu > 1:
461
+ model = torch.nn.DataParallel(model)
462
+
463
+ logger.info(" Training started")
464
+
465
+ global_step = 0
466
+ training_progress_scores = None
467
+ tr_loss, logging_loss = 0.0, 0.0
468
+ model.zero_grad()
469
+ train_iterator = trange(
470
+ int(args.num_train_epochs), desc="Epoch", disable=args.silent, mininterval=0
471
+ )
472
+ epoch_number = 0
473
+ best_eval_metric = None
474
+ early_stopping_counter = 0
475
+ steps_trained_in_current_epoch = 0
476
+ epochs_trained = 0
477
+
478
+ if args.model_name and os.path.exists(args.model_name):
479
+ try:
480
+ # set global_step to gobal_step of last saved checkpoint from model path
481
+ checkpoint_suffix = args.model_name.split("/")[-1].split("-")
482
+ if len(checkpoint_suffix) > 2:
483
+ checkpoint_suffix = checkpoint_suffix[1]
484
+ else:
485
+ checkpoint_suffix = checkpoint_suffix[-1]
486
+ global_step = int(checkpoint_suffix)
487
+ epochs_trained = global_step // (
488
+ len(train_dataloader) // args.gradient_accumulation_steps
489
+ )
490
+ steps_trained_in_current_epoch = global_step % (
491
+ len(train_dataloader) // args.gradient_accumulation_steps
492
+ )
493
+
494
+ logger.info(
495
+ " Continuing training from checkpoint, will skip to saved global_step"
496
+ )
497
+ logger.info(" Continuing training from epoch %d", epochs_trained)
498
+ logger.info(" Continuing training from global step %d", global_step)
499
+ logger.info(
500
+ " Will skip the first %d steps in the current epoch",
501
+ steps_trained_in_current_epoch,
502
+ )
503
+ except ValueError:
504
+ logger.info(" Starting fine-tuning.")
505
+
506
+ if args.evaluate_during_training:
507
+ training_progress_scores = self._create_training_progress_scores(**kwargs)
508
+
509
+ if args.wandb_project:
510
+ wandb.init(
511
+ project=args.wandb_project,
512
+ config={**asdict(args)},
513
+ **args.wandb_kwargs,
514
+ )
515
+ wandb.run._label(repo="textgen")
516
+ wandb.watch(self.model)
517
+ self.wandb_run_id = wandb.run.id
518
+
519
+ if args.fp16:
520
+ from torch.cuda import amp
521
+
522
+ scaler = amp.GradScaler()
523
+
524
+ for current_epoch in train_iterator:
525
+ model.train()
526
+ if epochs_trained > 0:
527
+ epochs_trained -= 1
528
+ continue
529
+ train_iterator.set_description(
530
+ f"Epoch {epoch_number + 1} of {args.num_train_epochs}"
531
+ )
532
+ batch_iterator = tqdm(
533
+ train_dataloader,
534
+ desc=f"Running Epoch {epoch_number} of {args.num_train_epochs}",
535
+ disable=args.silent,
536
+ mininterval=0,
537
+ )
538
+ for step, batch in enumerate(batch_iterator):
539
+ if steps_trained_in_current_epoch > 0:
540
+ steps_trained_in_current_epoch -= 1
541
+ continue
542
+
543
+ inputs = self._get_inputs_dict(batch)
544
+ if args.fp16:
545
+ with amp.autocast():
546
+ outputs = model(**inputs)
547
+ # model outputs are always tuple in pytorch-transformers (see doc)
548
+ loss = outputs[0]
549
+ else:
550
+ outputs = model(**inputs)
551
+ # model outputs are always tuple in pytorch-transformers (see doc)
552
+ loss = outputs[0]
553
+
554
+ if args.n_gpu > 1:
555
+ loss = (
556
+ loss.mean()
557
+ ) # mean() to average on multi-gpu parallel training
558
+
559
+ current_loss = loss.item()
560
+
561
+ if show_running_loss:
562
+ batch_iterator.set_description(
563
+ f"Epochs {epoch_number}/{args.num_train_epochs}. Running Loss: {current_loss:9.4f}"
564
+ )
565
+
566
+ if args.gradient_accumulation_steps > 1:
567
+ loss = loss / args.gradient_accumulation_steps
568
+
569
+ if args.fp16:
570
+ scaler.scale(loss).backward()
571
+ else:
572
+ loss.backward()
573
+
574
+ tr_loss += loss.item()
575
+ if (step + 1) % args.gradient_accumulation_steps == 0:
576
+ if args.fp16:
577
+ scaler.unscale_(optimizer)
578
+ if args.optimizer == "AdamW":
579
+ torch.nn.utils.clip_grad_norm_(
580
+ model.parameters(), args.max_grad_norm
581
+ )
582
+
583
+ if args.fp16:
584
+ scaler.step(optimizer)
585
+ scaler.update()
586
+ else:
587
+ optimizer.step()
588
+ scheduler.step() # Update learning rate schedule
589
+ model.zero_grad()
590
+ global_step += 1
591
+
592
+ if args.logging_steps > 0 and global_step % args.logging_steps == 0:
593
+ # Log metrics
594
+ tb_writer.add_scalar(
595
+ "lr", scheduler.get_last_lr()[0], global_step
596
+ )
597
+ tb_writer.add_scalar(
598
+ "loss",
599
+ (tr_loss - logging_loss) / args.logging_steps,
600
+ global_step,
601
+ )
602
+ logging_loss = tr_loss
603
+ if args.wandb_project or self.is_sweeping:
604
+ wandb.log(
605
+ {
606
+ "Training loss": current_loss,
607
+ "lr": scheduler.get_last_lr()[0],
608
+ "global_step": global_step,
609
+ }
610
+ )
611
+
612
+ if args.save_steps > 0 and global_step % args.save_steps == 0:
613
+ # Save model checkpoint
614
+ output_dir_current = os.path.join(
615
+ output_dir, "checkpoint-{}".format(global_step)
616
+ )
617
+
618
+ self.save_model(
619
+ output_dir_current, optimizer, scheduler, model=model
620
+ )
621
+
622
+ if args.evaluate_during_training and (
623
+ args.evaluate_during_training_steps > 0
624
+ and global_step % args.evaluate_during_training_steps == 0
625
+ ):
626
+ # Only evaluate when single GPU otherwise metrics may not average well
627
+ results = self.eval_model(
628
+ eval_data,
629
+ verbose=verbose and args.evaluate_during_training_verbose,
630
+ silent=args.evaluate_during_training_silent,
631
+ **kwargs,
632
+ )
633
+ for key, value in results.items():
634
+ try:
635
+ tb_writer.add_scalar(
636
+ "eval_{}".format(key), value, global_step
637
+ )
638
+ except (NotImplementedError, AssertionError):
639
+ pass
640
+
641
+ output_dir_current = os.path.join(
642
+ output_dir, "checkpoint-{}".format(global_step)
643
+ )
644
+
645
+ if args.save_eval_checkpoints:
646
+ self.save_model(
647
+ output_dir_current,
648
+ optimizer,
649
+ scheduler,
650
+ model=model,
651
+ results=results,
652
+ )
653
+
654
+ training_progress_scores["global_step"].append(global_step)
655
+ training_progress_scores["train_loss"].append(current_loss)
656
+ for key in results:
657
+ training_progress_scores[key].append(results[key])
658
+ report = pd.DataFrame(training_progress_scores)
659
+ report.to_csv(
660
+ os.path.join(
661
+ args.output_dir, "training_progress_scores.csv"
662
+ ),
663
+ index=False,
664
+ )
665
+
666
+ if args.wandb_project or self.is_sweeping:
667
+ wandb.log(self._get_last_metrics(training_progress_scores))
668
+
669
+ if not best_eval_metric:
670
+ best_eval_metric = results[args.early_stopping_metric]
671
+ self.save_model(
672
+ args.best_model_dir,
673
+ optimizer,
674
+ scheduler,
675
+ model=model,
676
+ results=results,
677
+ )
678
+ if best_eval_metric and args.early_stopping_metric_minimize:
679
+ if (
680
+ results[args.early_stopping_metric] - best_eval_metric
681
+ < args.early_stopping_delta
682
+ ):
683
+ best_eval_metric = results[args.early_stopping_metric]
684
+ self.save_model(
685
+ args.best_model_dir,
686
+ optimizer,
687
+ scheduler,
688
+ model=model,
689
+ results=results,
690
+ )
691
+ early_stopping_counter = 0
692
+ else:
693
+ if args.use_early_stopping:
694
+ if (
695
+ early_stopping_counter
696
+ < args.early_stopping_patience
697
+ ):
698
+ early_stopping_counter += 1
699
+ if verbose:
700
+ logger.info(
701
+ f" No improvement in {args.early_stopping_metric}"
702
+ )
703
+ logger.info(
704
+ f" Current step: {early_stopping_counter}"
705
+ )
706
+ logger.info(
707
+ f" Early stopping patience: {args.early_stopping_patience}"
708
+ )
709
+ else:
710
+ if verbose:
711
+ logger.info(
712
+ f" Patience of {args.early_stopping_patience} steps reached"
713
+ )
714
+ logger.info(" Training terminated.")
715
+ train_iterator.close()
716
+ return (
717
+ global_step,
718
+ tr_loss / global_step
719
+ if not self.args.evaluate_during_training
720
+ else training_progress_scores,
721
+ )
722
+ else:
723
+ if (
724
+ results[args.early_stopping_metric] - best_eval_metric
725
+ > args.early_stopping_delta
726
+ ):
727
+ best_eval_metric = results[args.early_stopping_metric]
728
+ self.save_model(
729
+ args.best_model_dir,
730
+ optimizer,
731
+ scheduler,
732
+ model=model,
733
+ results=results,
734
+ )
735
+ early_stopping_counter = 0
736
+ else:
737
+ if args.use_early_stopping:
738
+ if (
739
+ early_stopping_counter
740
+ < args.early_stopping_patience
741
+ ):
742
+ early_stopping_counter += 1
743
+ if verbose:
744
+ logger.info(
745
+ f" No improvement in {args.early_stopping_metric}"
746
+ )
747
+ logger.info(
748
+ f" Current step: {early_stopping_counter}"
749
+ )
750
+ logger.info(
751
+ f" Early stopping patience: {args.early_stopping_patience}"
752
+ )
753
+ else:
754
+ if verbose:
755
+ logger.info(
756
+ f" Patience of {args.early_stopping_patience} steps reached"
757
+ )
758
+ logger.info(" Training terminated.")
759
+ train_iterator.close()
760
+ return (
761
+ global_step,
762
+ tr_loss / global_step
763
+ if not self.args.evaluate_during_training
764
+ else training_progress_scores,
765
+ )
766
+ model.train()
767
+
768
+ epoch_number += 1
769
+ output_dir_current = os.path.join(
770
+ output_dir, "checkpoint-{}-epoch-{}".format(global_step, epoch_number)
771
+ )
772
+
773
+ if args.save_model_every_epoch:
774
+ self.save_model(output_dir_current, optimizer, scheduler, model=model)
775
+
776
+ if args.evaluate_during_training and args.evaluate_each_epoch:
777
+ results = self.eval_model(
778
+ eval_data,
779
+ verbose=verbose and args.evaluate_during_training_verbose,
780
+ silent=args.evaluate_during_training_silent,
781
+ **kwargs,
782
+ )
783
+
784
+ if args.save_eval_checkpoints:
785
+ self.save_model(
786
+ output_dir_current, optimizer, scheduler, results=results
787
+ )
788
+
789
+ training_progress_scores["global_step"].append(global_step)
790
+ training_progress_scores["train_loss"].append(current_loss)
791
+ for key in results:
792
+ training_progress_scores[key].append(results[key])
793
+ report = pd.DataFrame(training_progress_scores)
794
+ report.to_csv(
795
+ os.path.join(args.output_dir, "training_progress_scores.csv"),
796
+ index=False,
797
+ )
798
+
799
+ if args.wandb_project or self.is_sweeping:
800
+ wandb.log(self._get_last_metrics(training_progress_scores))
801
+
802
+ if not best_eval_metric:
803
+ best_eval_metric = results[args.early_stopping_metric]
804
+ self.save_model(
805
+ args.best_model_dir,
806
+ optimizer,
807
+ scheduler,
808
+ model=model,
809
+ results=results,
810
+ )
811
+ if best_eval_metric and args.early_stopping_metric_minimize:
812
+ if (
813
+ results[args.early_stopping_metric] - best_eval_metric
814
+ < args.early_stopping_delta
815
+ ):
816
+ best_eval_metric = results[args.early_stopping_metric]
817
+ self.save_model(
818
+ args.best_model_dir,
819
+ optimizer,
820
+ scheduler,
821
+ model=model,
822
+ results=results,
823
+ )
824
+ early_stopping_counter = 0
825
+ else:
826
+ if (
827
+ args.use_early_stopping
828
+ and args.early_stopping_consider_epochs
829
+ ):
830
+ if early_stopping_counter < args.early_stopping_patience:
831
+ early_stopping_counter += 1
832
+ if verbose:
833
+ logger.info(
834
+ f" No improvement in {args.early_stopping_metric}"
835
+ )
836
+ logger.info(
837
+ f" Current step: {early_stopping_counter}"
838
+ )
839
+ logger.info(
840
+ f" Early stopping patience: {args.early_stopping_patience}"
841
+ )
842
+ else:
843
+ if verbose:
844
+ logger.info(
845
+ f" Patience of {args.early_stopping_patience} steps reached"
846
+ )
847
+ logger.info(" Training terminated.")
848
+ train_iterator.close()
849
+ return (
850
+ global_step,
851
+ tr_loss / global_step
852
+ if not self.args.evaluate_during_training
853
+ else training_progress_scores,
854
+ )
855
+ else:
856
+ if (
857
+ results[args.early_stopping_metric] - best_eval_metric
858
+ > args.early_stopping_delta
859
+ ):
860
+ best_eval_metric = results[args.early_stopping_metric]
861
+ self.save_model(
862
+ args.best_model_dir,
863
+ optimizer,
864
+ scheduler,
865
+ model=model,
866
+ results=results,
867
+ )
868
+ early_stopping_counter = 0
869
+ else:
870
+ if (
871
+ args.use_early_stopping
872
+ and args.early_stopping_consider_epochs
873
+ ):
874
+ if early_stopping_counter < args.early_stopping_patience:
875
+ early_stopping_counter += 1
876
+ if verbose:
877
+ logger.info(
878
+ f" No improvement in {args.early_stopping_metric}"
879
+ )
880
+ logger.info(
881
+ f" Current step: {early_stopping_counter}"
882
+ )
883
+ logger.info(
884
+ f" Early stopping patience: {args.early_stopping_patience}"
885
+ )
886
+ else:
887
+ if verbose:
888
+ logger.info(
889
+ f" Patience of {args.early_stopping_patience} steps reached"
890
+ )
891
+ logger.info(" Training terminated.")
892
+ train_iterator.close()
893
+ return (
894
+ global_step,
895
+ tr_loss / global_step
896
+ if not self.args.evaluate_during_training
897
+ else training_progress_scores,
898
+ )
899
+
900
+ return (
901
+ global_step,
902
+ tr_loss / global_step
903
+ if not self.args.evaluate_during_training
904
+ else training_progress_scores,
905
+ )
906
+
907
+ def eval_model(
908
+ self, eval_data, output_dir=None, verbose=True, silent=False, **kwargs
909
+ ):
910
+ """
911
+ Evaluates the model on eval_data. Saves results to output_dir.
912
+
913
+ Args:
914
+ eval_data: Pandas DataFrame containing the 3 columns - `prefix`, `input_text`, `target_text`.
915
+ - `prefix`: A string indicating the task to perform. (E.g. `"question"`, `"stsb"`)
916
+ - `input_text`: The input text sequence. `prefix` is automatically prepended to form the full input. (<prefix>: <input_text>)
917
+ - `target_text`: The target sequence
918
+ output_dir: The directory where model files will be saved. If not given, self.args.output_dir will be used.
919
+ verbose: If verbose, results will be printed to the console on completion of evaluation.
920
+ silent: If silent, tqdm progress bars will be hidden.
921
+ **kwargs: Additional metrics that should be used. Pass in the metrics as keyword arguments (name of metric: function to use).
922
+ A metric function should take in two parameters. The first parameter will be the true labels, and the second parameter will be the predictions. Both inputs
923
+ will be lists of strings. Note that this will slow down evaluation significantly as the predicted sequences need to be generated.
924
+ Returns:
925
+ results: Dictionary containing evaluation results.
926
+ """ # noqa: ignore flake8"
927
+
928
+ if not output_dir:
929
+ output_dir = self.args.output_dir
930
+
931
+ self._move_model_to_device()
932
+
933
+ eval_dataset = self.load_and_cache_examples(
934
+ eval_data, evaluate=True, verbose=verbose, silent=silent
935
+ )
936
+ os.makedirs(output_dir, exist_ok=True)
937
+
938
+ result = self.evaluate(
939
+ eval_dataset, output_dir, verbose=verbose, silent=silent, **kwargs
940
+ )
941
+ self.results.update(result)
942
+
943
+ if self.args.evaluate_generated_text:
944
+ if self.args.preprocess_inputs:
945
+ to_predict = [
946
+ input_text
947
+ for prefix, input_text in zip(
948
+ eval_data["prefix"], eval_data["input_text"]
949
+ )
950
+ ]
951
+ else:
952
+ to_predict = [
953
+ prefix + input_text
954
+ for prefix, input_text in zip(
955
+ eval_data["prefix"], eval_data["input_text"]
956
+ )
957
+ ]
958
+ preds = self.predict(to_predict[:self.args.eval_batch_size*3])
959
+
960
+ result = self.compute_metrics(
961
+ eval_data["target_text"].tolist()[:self.args.eval_batch_size*3], preds, **kwargs
962
+ )
963
+ self.results.update(result)
964
+
965
+ if verbose:
966
+ logger.info(self.results)
967
+
968
+ return self.results
969
+
970
+ def evaluate(self, eval_dataset, output_dir, verbose=True, silent=False, **kwargs):
971
+ """
972
+ Evaluates the model on eval_dataset.
973
+
974
+ Utility function to be used by the eval_model() method. Not intended to be used directly.
975
+ """
976
+
977
+ model = self.model
978
+ args = self.args
979
+ eval_output_dir = output_dir
980
+ device = self.device
981
+
982
+ results = {}
983
+
984
+ eval_sampler = SequentialSampler(eval_dataset)
985
+ eval_dataloader = DataLoader(
986
+ eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
987
+ )
988
+
989
+ if args.n_gpu > 1:
990
+ model = torch.nn.DataParallel(model)
991
+
992
+ eval_loss = 0.0
993
+ nb_eval_steps = 0
994
+ model.eval()
995
+
996
+ if self.args.fp16:
997
+ from torch.cuda import amp
998
+
999
+ for batch in tqdm(
1000
+ eval_dataloader, disable=args.silent or silent, desc="Running Evaluation"
1001
+ ):
1002
+ inputs = self._get_inputs_dict(batch)
1003
+ with torch.no_grad():
1004
+ if self.args.fp16:
1005
+ with amp.autocast():
1006
+ outputs = model(**inputs)
1007
+ loss = outputs[0]
1008
+ else:
1009
+ outputs = model(**inputs)
1010
+ loss = outputs[0]
1011
+ if self.args.n_gpu > 1:
1012
+ loss = loss.mean()
1013
+ eval_loss += loss.item()
1014
+ nb_eval_steps += 1
1015
+
1016
+ eval_loss = eval_loss / nb_eval_steps
1017
+
1018
+ results["eval_loss"] = eval_loss
1019
+
1020
+ output_eval_file = os.path.join(eval_output_dir, "eval_results.txt")
1021
+ with open(output_eval_file, "w") as writer:
1022
+ for key in sorted(results.keys()):
1023
+ writer.write("{} = {}\n".format(key, str(results[key])))
1024
+
1025
+ return results
1026
+
1027
+ def predict(self, to_predict, split_on_space=False):
1028
+ """
1029
+ Performs predictions on a list of text.
1030
+
1031
+ Args:
1032
+ to_predict: A python list of text (str) to be sent to the model for prediction. Note that the prefix should be prepended to the text.
1033
+ split_on_space (optional): If True, input is english string, if False, input is chinese string.
1034
+
1035
+ Returns:
1036
+ preds: A python list of the generated sequences.
1037
+ """ # noqa: ignore flake8"
1038
+
1039
+ self._move_model_to_device()
1040
+
1041
+ all_outputs = []
1042
+ # Batching
1043
+ for batch in tqdm(
1044
+ [
1045
+ to_predict[i: i + self.args.eval_batch_size]
1046
+ for i in range(0, len(to_predict), self.args.eval_batch_size)
1047
+ ],
1048
+ desc="Generating outputs",
1049
+ disable=self.args.silent,
1050
+ ):
1051
+ input_batch = self.tokenizer.prepare_seq2seq_batch(
1052
+ src_texts=batch,
1053
+ max_length=self.args.max_seq_length,
1054
+ padding="max_length",
1055
+ return_tensors="pt",
1056
+ truncation=True,
1057
+ )
1058
+ input_ids = input_batch["input_ids"]
1059
+ attention_mask = input_batch["attention_mask"]
1060
+
1061
+ input_ids = input_ids.to(self.device)
1062
+ attention_mask = attention_mask.to(self.device)
1063
+
1064
+ outputs = self.model.generate(
1065
+ input_ids=input_ids,
1066
+ attention_mask=attention_mask,
1067
+ num_beams=self.args.num_beams,
1068
+ max_length=self.args.max_length,
1069
+ length_penalty=self.args.length_penalty,
1070
+ early_stopping=self.args.early_stopping,
1071
+ repetition_penalty=self.args.repetition_penalty,
1072
+ do_sample=self.args.do_sample,
1073
+ top_k=self.args.top_k,
1074
+ top_p=self.args.top_p,
1075
+ num_return_sequences=self.args.num_return_sequences,
1076
+ #streamer=self.streamer,
1077
+ )
1078
+ all_outputs.extend(outputs.cpu().numpy())
1079
+
1080
+ if self.args.use_multiprocessed_decoding:
1081
+ self.model.to("cpu")
1082
+ with Pool(self.args.process_count) as p:
1083
+ if self.args.multiprocessing_chunksize == -1:
1084
+ chunksize = max(
1085
+ len(all_outputs) // (self.args.process_count * 2), 500
1086
+ )
1087
+ else:
1088
+ chunksize = self.args.multiprocessing_chunksize
1089
+ outputs = list(
1090
+ tqdm(
1091
+ p.imap(self._decode, all_outputs, chunksize=chunksize),
1092
+ total=len(all_outputs),
1093
+ desc="Decoding outputs",
1094
+ disable=self.args.silent,
1095
+ )
1096
+ )
1097
+ self._move_model_to_device()
1098
+ else:
1099
+ outputs = [
1100
+ self.tokenizer.decode(
1101
+ output_id,
1102
+ skip_special_tokens=self.args.skip_special_tokens,
1103
+ clean_up_tokenization_spaces=True,
1104
+ )
1105
+ for output_id in all_outputs
1106
+ ]
1107
+ if not split_on_space:
1108
+ outputs = [''.join(gen_text.split(' ')) for gen_text in outputs]
1109
+ if self.args.num_return_sequences > 1:
1110
+ return [
1111
+ outputs[i: i + self.args.num_return_sequences]
1112
+ for i in range(0, len(outputs), self.args.num_return_sequences)
1113
+ ]
1114
+ else:
1115
+ return outputs
1116
+
1117
+ def _decode(self, output_id):
1118
+ return self.tokenizer.decode(
1119
+ output_id,
1120
+ skip_special_tokens=self.args.skip_special_tokens,
1121
+ clean_up_tokenization_spaces=True,
1122
+ )
1123
+
1124
+ def compute_metrics(self, labels, preds, **kwargs):
1125
+ """
1126
+ Computes the evaluation metrics for the model predictions.
1127
+
1128
+ Args:
1129
+ labels: List of target sequences
1130
+ preds: List of model generated outputs
1131
+ **kwargs: Custom metrics that should be used. Pass in the metrics as keyword arguments (name of metric: function to use).
1132
+ A metric function should take in two parameters. The first parameter will be the true labels, and the second parameter will be the predictions. Both inputs
1133
+ will be lists of strings. Note that this will slow down evaluation significantly as the predicted sequences need to be generated.
1134
+
1135
+ Returns:
1136
+ result: Dictionary containing evaluation results.
1137
+ """ # noqa: ignore flake8"
1138
+ assert len(labels) == len(preds)
1139
+
1140
+ results = {}
1141
+ for metric, func in kwargs.items():
1142
+ results[metric] = func(labels, preds)
1143
+
1144
+ return results
1145
+
1146
+ def _move_model_to_device(self):
1147
+ self.model.to(self.device)
1148
+
1149
+ def _get_inputs_dict(self, batch):
1150
+ if self.args.use_hf_datasets:
1151
+ inputs = {**batch, "labels": batch["input_ids"]}
1152
+
1153
+ return {key: value.to(self.device) for key, value in inputs.items()}
1154
+ else:
1155
+ batch = tuple(t.to(self.device) for t in batch)
1156
+
1157
+ input_ids = batch[0]
1158
+ attention_mask = batch[1]
1159
+ labels = batch[2]
1160
+ labels[labels == self.tokenizer.pad_token_id] = -100
1161
+
1162
+ inputs = {
1163
+ "input_ids": input_ids,
1164
+ "attention_mask": attention_mask,
1165
+ "labels": labels,
1166
+ }
1167
+
1168
+ return inputs
1169
+
1170
+ def load_and_cache_examples(
1171
+ self, data, evaluate=False, no_cache=False, verbose=True, silent=False
1172
+ ):
1173
+ """
1174
+ Creates a T5Dataset from data.
1175
+
1176
+ Utility function for train() and eval() methods. Not intended to be used directly.
1177
+ """
1178
+
1179
+ tokenizer = self.tokenizer
1180
+ args = self.args
1181
+
1182
+ if not no_cache:
1183
+ no_cache = args.no_cache
1184
+
1185
+ if not no_cache:
1186
+ os.makedirs(self.args.cache_dir, exist_ok=True)
1187
+
1188
+ mode = "dev" if evaluate else "train"
1189
+
1190
+ if self.args.use_hf_datasets:
1191
+ dataset = load_hf_dataset(data, tokenizer, self.args)
1192
+ return dataset
1193
+ elif args.dataset_class:
1194
+ CustomDataset = args.dataset_class
1195
+ return CustomDataset(tokenizer, args, data, mode)
1196
+ else:
1197
+ return T5Dataset(
1198
+ tokenizer,
1199
+ self.args,
1200
+ data,
1201
+ mode,
1202
+ )
1203
+
1204
+ def _create_training_progress_scores(self, **kwargs):
1205
+ extra_metrics = {key: [] for key in kwargs}
1206
+ training_progress_scores = {
1207
+ "global_step": [],
1208
+ "eval_loss": [],
1209
+ "train_loss": [],
1210
+ **extra_metrics,
1211
+ }
1212
+
1213
+ return training_progress_scores
1214
+
1215
+ def _get_last_metrics(self, metric_values):
1216
+ return {metric: values[-1] for metric, values in metric_values.items()}
1217
+
1218
+ def save_model(
1219
+ self, output_dir=None, optimizer=None, scheduler=None, model=None, results=None
1220
+ ):
1221
+ if not output_dir:
1222
+ output_dir = self.args.output_dir
1223
+ os.makedirs(output_dir, exist_ok=True)
1224
+
1225
+ if model and not self.args.no_save:
1226
+ # Take care of distributed/parallel training
1227
+ model_to_save = model.module if hasattr(model, "module") else model
1228
+ model_to_save.save_pretrained(output_dir)
1229
+ self.tokenizer.save_pretrained(output_dir)
1230
+ torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
1231
+ if optimizer and scheduler and self.args.save_optimizer_and_scheduler:
1232
+ torch.save(
1233
+ optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")
1234
+ )
1235
+ torch.save(
1236
+ scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")
1237
+ )
1238
+ self.save_model_args(output_dir)
1239
+
1240
+ if results:
1241
+ output_eval_file = os.path.join(output_dir, "eval_results.txt")
1242
+ with open(output_eval_file, "w") as writer:
1243
+ for key in sorted(results.keys()):
1244
+ writer.write("{} = {}\n".format(key, str(results[key])))
1245
+
1246
+ def save_model_args(self, output_dir):
1247
+ os.makedirs(output_dir, exist_ok=True)
1248
+ self.args.save(output_dir)
1249
+
1250
+ def _load_model_args(self, input_dir):
1251
+ args = T5Args()
1252
+ args.load(input_dir)
1253
+ return args
1254
+
1255
+ def get_named_parameters(self):
1256
+ return [n for n, p in self.model.named_parameters()]
t5/t5_utils.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ @author:XuMing([email protected])
4
+ @description: adjust for chinese tokenizer
5
+ """
6
+ import os
7
+ import pickle
8
+ from multiprocessing import Pool
9
+
10
+ from datasets import Dataset as HFDataset
11
+ from datasets import load_dataset
12
+ from torch.utils.data import Dataset
13
+ from tqdm.auto import tqdm
14
+ from rouge import Rouge
15
+ from loguru import logger
16
+
17
+
18
+ def preprocess_batch_for_hf_dataset(dataset, tokenizer, args):
19
+ if args.preprocess_inputs:
20
+ return tokenizer.prepare_seq2seq_batch(
21
+ src_texts=[
22
+ prefix + ": " + input_text
23
+ for prefix, input_text in zip(dataset["prefix"], dataset["input_text"])
24
+ ],
25
+ tgt_texts=dataset["target_text"],
26
+ max_length=args.max_seq_length,
27
+ max_target_length=args.max_length,
28
+ padding="max_length",
29
+ return_tensors="np",
30
+ truncation=True,
31
+ )
32
+ else:
33
+ return tokenizer.prepare_seq2seq_batch(
34
+ src_texts=[
35
+ prefix + input_text
36
+ for prefix, input_text in zip(dataset["prefix"], dataset["input_text"])
37
+ ],
38
+ tgt_texts=dataset["target_text"],
39
+ max_length=args.max_seq_length,
40
+ max_target_length=args.max_length,
41
+ padding="max_length",
42
+ return_tensors="np",
43
+ truncation=True,
44
+ )
45
+
46
+
47
+ def load_hf_dataset(data, tokenizer, args):
48
+ if isinstance(data, str):
49
+ dataset = load_dataset(
50
+ "csv",
51
+ data_files=data,
52
+ delimiter="\t",
53
+ download_mode="force_redownload"
54
+ if args.reprocess_input_data
55
+ else "reuse_dataset_if_exists",
56
+ )
57
+ else:
58
+ dataset = HFDataset.from_pandas(data)
59
+
60
+ dataset = dataset.map(
61
+ lambda x: preprocess_batch_for_hf_dataset(x, tokenizer=tokenizer, args=args),
62
+ batched=True,
63
+ )
64
+
65
+ dataset.set_format(type="pt", columns=["input_ids", "attention_mask"])
66
+
67
+ if isinstance(data, str):
68
+ # This is not necessarily a train dataset. The datasets library insists on calling it train.
69
+ return dataset["train"]
70
+ else:
71
+ return dataset
72
+
73
+
74
+ def preprocess_data(data):
75
+ prefix, input_text, target_text, tokenizer, args = data
76
+
77
+ # Add EOS again if truncated?
78
+ if args.preprocess_inputs:
79
+ batch = tokenizer.prepare_seq2seq_batch(
80
+ src_texts=[prefix + ": " + input_text],
81
+ tgt_texts=[target_text],
82
+ max_length=args.max_seq_length,
83
+ padding="max_length",
84
+ return_tensors="pt",
85
+ truncation=True,
86
+ )
87
+ else:
88
+ batch = tokenizer.prepare_seq2seq_batch(
89
+ src_texts=[prefix + ": " + input_text],
90
+ tgt_texts=[target_text],
91
+ max_length=args.max_seq_length,
92
+ padding="max_length",
93
+ return_tensors="pt",
94
+ truncation=True,
95
+ )
96
+ input_ids = batch["input_ids"][0]
97
+ attention_mask = batch["attention_mask"][0]
98
+ labels = batch["labels"][0]
99
+ return (input_ids, attention_mask, labels)
100
+
101
+
102
+ class T5Dataset(Dataset):
103
+ def __init__(self, tokenizer, args, data, mode):
104
+ cached_features_file = os.path.join(
105
+ args.cache_dir,
106
+ args.model_name.replace("/", "_")
107
+ + "_cached_"
108
+ + str(args.max_seq_length)
109
+ + str(len(data)),
110
+ )
111
+
112
+ if os.path.exists(cached_features_file) and (
113
+ (not args.reprocess_input_data and not args.no_cache)
114
+ or (mode == "dev" and args.use_cached_eval_features and not args.no_cache)
115
+ ):
116
+ logger.info(" Loading features from cached file %s" % cached_features_file)
117
+ with open(cached_features_file, "rb") as handle:
118
+ self.examples = pickle.load(handle)
119
+ else:
120
+ logger.info(" Creating features from dataset file at %s" % args.cache_dir)
121
+
122
+ data = [
123
+ (prefix, input_text, target_text, tokenizer, args)
124
+ for prefix, input_text, target_text in zip(
125
+ data["prefix"], data["input_text"], data["target_text"]
126
+ )
127
+ ]
128
+
129
+ if (mode == "train" and args.use_multiprocessing) or (
130
+ mode == "dev" and args.use_multiprocessing_for_evaluation
131
+ ):
132
+ if args.multiprocessing_chunksize == -1:
133
+ chunksize = max(len(data) // (args.process_count * 2), 500)
134
+ else:
135
+ chunksize = args.multiprocessing_chunksize
136
+
137
+ with Pool(args.process_count) as p:
138
+ self.examples = list(
139
+ tqdm(
140
+ p.imap(preprocess_data, data, chunksize=chunksize),
141
+ total=len(data),
142
+ disable=args.silent,
143
+ )
144
+ )
145
+ else:
146
+ self.examples = [preprocess_data(d) for d in tqdm(data, disable=args.silent)]
147
+ if not args.no_cache:
148
+ logger.info(" Saving features into cached file %s" % cached_features_file)
149
+ with open(cached_features_file, "wb") as handle:
150
+ pickle.dump(self.examples, handle, protocol=pickle.HIGHEST_PROTOCOL)
151
+
152
+ def __len__(self):
153
+ return len(self.examples)
154
+
155
+ def __getitem__(self, index):
156
+ return self.examples[index]
157
+
158
+
159
+ def dynamic_lcs(X, Y):
160
+ # find the length of the strings
161
+ m = len(X)
162
+ n = len(Y)
163
+
164
+ # declaring the array for storing the dp values
165
+ L = [[None] * (n + 1) for i in range(m + 1)]
166
+
167
+ """Following steps build L[m + 1][n + 1] in bottom up fashion
168
+ Note: L[i][j] contains length of LCS of X[0..i-1]
169
+ and Y[0..j-1]"""
170
+ for i in range(m + 1):
171
+ for j in range(n + 1):
172
+ if i == 0 or j == 0:
173
+ L[i][j] = 0
174
+ elif X[i - 1] == Y[j - 1]:
175
+ L[i][j] = L[i - 1][j - 1] + 1
176
+ else:
177
+ L[i][j] = max(L[i - 1][j], L[i][j - 1])
178
+
179
+ # L[m][n] contains the length of LCS of X[0..n-1] & Y[0..m-1]
180
+ return L[m][n]
181
+
182
+
183
+ def f1_sim(text_a, text_b):
184
+ """F1相似度
185
+ 说明:算出两个文本的最长公共子序列长度,然后乘2并处以两者
186
+ 长度之和。
187
+ 脚本见:https://github.com/CLUEbenchmark/pCLUE/blob/main/evaluate_pclue.py
188
+ 计算pCLUE任务总分,及子分数
189
+ """
190
+ if not text_a and not text_b:
191
+ return 0.
192
+ lcs_len = dynamic_lcs(text_a, text_b)
193
+ return 2. * lcs_len / (len(text_a) + len(text_b))
194
+
195
+
196
+ def rouge_l_zh(target, pred):
197
+ """计算Rouge-l得分,Rouge-l指标常用于评估自动文本摘要及翻译任务
198
+ target: 真实标签
199
+ pred: 预测标签"""
200
+
201
+ if not (isinstance(target, str) or isinstance(pred, str)):
202
+ logger.error("target或pred为非字符串!请检查!")
203
+ return 0
204
+ rouge = Rouge()
205
+ scores = rouge.get_scores(" ".join(list(pred)), " ".join(list(target)))
206
+ score = scores[0]["rouge-l"]
207
+ return score["f"]
208
+
209
+
210
+ if __name__ == '__main__':
211
+ a = '123444'
212
+ b = '23411'
213
+ print(f1_sim(a, b))
214
+ print(dynamic_lcs(a, b))