File size: 3,676 Bytes
6ea7bd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a322426
6ea7bd8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import gc
import torch

import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, BitsAndBytesConfig
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model, prepare_model_for_kbit_training
from trl import DPOTrainer
import bitsandbytes as bnb
import wandb

# Credentials come from environment variables (in Colab / hosted Spaces these
# are populated from the secrets settings). Never hard-code API keys here or in
# comments: any key that has been committed to version control must be revoked.
hf_token = os.getenv("hf_token")
# Look the W&B key up by its variable *name*; passing the key itself to
# os.getenv() would always return None.
wb_token = os.getenv("wb_token")
# With key=None, wandb.login falls back to its own lookup (env var / netrc /
# interactive prompt).
wandb.login(key=wb_token)



# Fine-tune model with DPO


import gradio as gr


def greet(traindata_, output_repo):
    """Fine-tune zephyr-7b-gemma with DPO + LoRA on a dataset (Gradio callback).

    Loads the ``train`` split of ``traindata_`` with ``datasets.load_dataset``,
    loads the base model once in 4-bit NF4 quantization, trains LoRA adapters
    with ``DPOTrainer`` for 200 steps, and saves the result to ``output_repo``.

    Args:
        traindata_: Dataset id or path passed to ``load_dataset``. Presumably it
            must expose the ``prompt``/``chosen``/``rejected`` columns that
            DPOTrainer expects (TODO confirm against the datasets used).
        output_repo: Local directory used as the trainer's ``output_dir`` and as
            the destination for the saved adapter and tokenizer.

    Returns:
        A status string displayed in the Gradio output box.
    """
    traindata_ = (traindata_ or "").strip()
    output_repo = (output_repo or "").strip()
    if not traindata_ or not output_repo:
        # Fail fast instead of spending minutes loading a 7B model for a run
        # that would crash on an empty dataset id / output path.
        return "Please provide both a training data repo and an output model path."

    model_name = "HuggingFaceH4/zephyr-7b-gemma-v0.1"
    new_model = output_repo

    # Load the dataset before the model so a bad repo id fails cheaply.
    dataset = load_dataset(traindata_, split="train")

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Load the base model exactly once, quantized to 4-bit NF4. No separate
    # reference model is loaded: ref_model=None is passed to DPOTrainer below.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    device_map = {"": torch.cuda.current_device()} if torch.cuda.is_available() else None
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        # attn_implementation="flash_attention_2",  # enable if the GPU supports it
        torch_dtype="auto",
        use_cache=False,  # KV cache is unused with gradient checkpointing
        device_map=device_map,
        quantization_config=quantization_config,
    )

    # LoRA adapters on all attention and MLP projections.
    peft_config = LoraConfig(
        r=16,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['k_proj', 'gate_proj', 'v_proj', 'up_proj', 'q_proj', 'o_proj', 'down_proj'],
    )
    training_args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        max_steps=200,
        save_strategy="no",
        logging_steps=1,
        output_dir=new_model,
        optim="paged_adamw_32bit",
        warmup_steps=100,
        bf16=True,
        report_to="wandb",
    )

    dpo_trainer = None
    try:
        # With ref_model=None and a peft_config, TRL uses the base model with
        # adapters disabled as the implicit reference (per the TRL docs).
        dpo_trainer = DPOTrainer(
            model,
            ref_model=None,
            args=training_args,
            train_dataset=dataset,
            tokenizer=tokenizer,
            peft_config=peft_config,
            beta=0.1,
            # The prompt budget must be smaller than the total length budget;
            # otherwise TRL's response truncation length
            # (max_length - max_prompt_length) goes negative.
            max_prompt_length=1024,
            max_length=1536,
        )
        dpo_trainer.train()
        # save_strategy="no" writes nothing during training, so persist the
        # trained adapter (and tokenizer) explicitly or the run is lost.
        dpo_trainer.save_model(new_model)
    finally:
        # Release GPU memory and close the W&B run so repeated clicks in the
        # same process start clean.
        del dpo_trainer, model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        wandb.finish()
    return "Training Done"


# Minimal Gradio front end: two text inputs are passed to ``greet`` and its
# returned status string is shown in the output box. The click handler is
# registered under the API name "greet".
with gr.Blocks() as demo:
    dataset_repo_box = gr.Textbox(label="Enter training data repo")
    output_path_box = gr.Textbox(label="Enter output model path")

    status_box = gr.Textbox(label="Output Box")
    train_button = gr.Button("TRAIN")
    train_button.click(
        fn=greet,
        inputs=[dataset_repo_box, output_path_box],
        outputs=status_box,
        api_name="greet",
    )

demo.launch()