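"""Convert a 4-bit GPTQ (AutoGPTQ) checkpoint into the Marlin kernel format.

Example invocation (script name and paths are illustrative placeholders):

    python convert_gptq_to_marlin.py \
        --model-id <gptq-model-id-or-path> \
        --save-path <output-dir> \
        --do-generation
"""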
import argparse
import copy
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str)
parser.add_argument("--save-path", type=str)
parser.add_argument("--do-generation", action="store_true")

def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin Format")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4 bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin Format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act order quantization cannot be converted to Marlin Format. You passed a model with desc_act={quantization_config.desc_act}")

@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack the eight 4-bit values packed into each 32-bit element of qweight / qzeros.
    unpacked_weights = torch.zeros((qweight.shape[0]*8, qweight.shape[1]), dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros((qzeros.shape[0], qzeros.shape[1]*8), dtype=torch.int8, device=qzeros.device, requires_grad=False)

    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF

    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF

    # GPTQ packs zero points offset by one, so add 1 to recover the actual zero points.
    return unpacked_weights, unpacked_zeros + 1

@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales

    return unpacked_qweight.T
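
# Conversion: dequantize each GPTQ QuantLinear back to fp16, then re-pack it into a
# MarlinLayer (reusing the original per-group scales) and swap it into the model.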

@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Re-pack the fp16 weight into a Marlin layer, reusing the original per-group scales.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()

    return model
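
# Alternative path (not used in __main__): replace every QuantLinear with a plain fp16
# nn.Linear holding the dequantized weights, keeping the original scales as a parameter.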

@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue

        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Replace the quantized module with a plain fp16 nn.Linear.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))
        
        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()

    return model

if __name__ == "__main__":
    args = parser.parse_args()
    model_id = args.model_id
    save_path = args.save_path
    do_generation = args.do_generation
    
    print("Loading gptq model...")
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Validate that this model is compatible with Marlin.
    print("Validating compatibility...")
    _validate_compatibility(model)

    # Convert the quantized modules to Marlin format, then move the model to CPU for saving.
    print("Converting model...")
    model = convert_model(model).to("cpu")

    # Save after updating quantization config.
    print("Saving marlin model...")
    model.config.quantization_config = {
        "group_size": model.config.quantization_config.group_size,
        "quant_method": "marlin"
    }
    model.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path)

    if do_generation:
        print("Generating sample text...")
        model.to("cuda")
        prompt = "My favorite song is"
        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to("cuda") for k, v in inputs.items()}
        outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
        print(tokenizer.batch_decode(outputs)[0])