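"""Gradio app that fine-tunes HuggingFaceH4/zephyr-7b-gemma-v0.1 with DPO.

The user supplies a preference-dataset repo id and an output path; the app
loads the base model in 4-bit, attaches LoRA adapters, runs TRL's DPOTrainer,
and logs metrics to Weights & Biases.
"""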
import os

import gradio as gr
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from trl import DPOTrainer

# Tokens are read from environment variables (set via the secrets tab in
# Google Colab or a Hugging Face Space) -- never hardcode them in source.
hf_token = os.getenv("hf_token")
wb_token = os.getenv("wb_token")
if wb_token:
    wandb.login(key=wb_token)



# Fine-tune the model with DPO


def train_dpo(traindata_, output_repo):
    """Run DPO fine-tuning on `traindata_` and save the adapter to `output_repo`."""
    model_name = "HuggingFaceH4/zephyr-7b-gemma-v0.1"
    new_model = output_repo

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "left"

        # Quantize the base model to 4-bit NF4 so a 7B model fits on a single
        # GPU; matmuls are computed in bfloat16.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None

        # Load the base model (zephyr-7b-gemma) once, in 4-bit. No separate
        # reference model is loaded: when a PEFT config is given, DPOTrainer
        # uses the frozen base weights (adapters disabled) as the reference.
        model_kwargs = dict(
            # attn_implementation="flash_attention_2",  # enable if your GPU supports it
            torch_dtype="auto",
            use_cache=False,  # incompatible with gradient checkpointing
            device_map=device_map,
            quantization_config=quantization_config,
            token=hf_token,
        )
        model = AutoModelForCausalLM.from_pretrained(model_name, **model_kwargs)
    
        # LoRA configuration: low-rank adapters on every attention and MLP
        # projection matrix.
        peft_config = LoraConfig(
            r=16,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'],
        )
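        # Effective batch size is per_device_train_batch_size x
        # gradient_accumulation_steps = 16 sequences per optimizer step.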
        training_args = TrainingArguments(
            per_device_train_batch_size=4,
            gradient_accumulation_steps=4,
            gradient_checkpointing=True,
            learning_rate=5e-5,
            lr_scheduler_type="cosine",
            max_steps=200,
            save_strategy="no",
            logging_steps=1,
            output_dir=new_model,
            optim="paged_adamw_32bit",
            warmup_steps=100,
            bf16=True,
            report_to="wandb",
        )
    
        # Load the preference dataset; DPOTrainer expects "prompt", "chosen",
        # and "rejected" columns.
        dataset = load_dataset(traindata_, split='train')
    
        # Create the DPO trainer. max_length bounds the full prompt+completion
        # sequence, so it must be at least max_prompt_length.
        dpo_trainer = DPOTrainer(
            model,
            ref_model=None,  # frozen base weights act as the reference model
            args=training_args,
            train_dataset=dataset,
            tokenizer=tokenizer,
            peft_config=peft_config,
            beta=0.1,  # strength of the implicit KL penalty toward the reference
            max_prompt_length=1024,
            max_length=1536,
        )
        dpo_trainer.train()
        # save_strategy is "no", so persist the trained adapter explicitly.
        dpo_trainer.model.save_pretrained(new_model)
        tokenizer.save_pretrained(new_model)
        return "Training Done"
    except Exception as e:
        return str(e)


with gr.Blocks() as demo:
    traindata_ = gr.Textbox(label="Enter training data repo")
    output_repo = gr.Textbox(label="Enter output model path")

    output = gr.Textbox(label="Output Box")
    train_btn = gr.Button("TRAIN")
    train_btn.click(fn=train_dpo, inputs=[traindata_, output_repo], outputs=output, api_name="train")

demo.launch()
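

# Example client call once the app is running (repo names below are
# placeholders, and the URL assumes the default local launch):
#
#   from gradio_client import Client
#
#   client = Client("http://127.0.0.1:7860/")
#   result = client.predict(
#       "your-username/your-dpo-dataset",   # traindata_: preference dataset repo id
#       "your-username/zephyr-gemma-dpo",   # output_repo: where the adapter is saved
#       api_name="/train",
#   )
#   print(result)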