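"""Convert an AutoGPTQ GPTQ checkpoint into the Marlin serialization format.

Each QuantLinear module is dequantized back to fp16 and repacked as a
MarlinLayer. Marlin only supports 4-bit, group-size-128, symmetric,
non-act-order GPTQ models; _validate_compatibility() enforces this.
"""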
import argparse
import copy
import gc

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_gptq.nn_modules.qlinear.qlinear_exllama import QuantLinear
from marlin import Layer as MarlinLayer

parser = argparse.ArgumentParser()
parser.add_argument("--model-id", type=str, help="Hub id or local path of the GPTQ model to convert")
parser.add_argument("--save-path", type=str, help="Directory to write the converted Marlin checkpoint to")
parser.add_argument("--do-generation", action="store_true", help="Run a sample generation with the converted model")

def _validate_compatibility(model):
    if not hasattr(model.config, "quantization_config"):
        raise ValueError("Must be a quantized model to convert to Marlin format.")
    quantization_config = model.config.quantization_config
    if quantization_config.quant_method != "gptq":
        raise ValueError(f"Only GPTQ models can be converted to Marlin format. You passed a model with quant_method={quantization_config.quant_method}")
    if quantization_config.bits != 4:
        raise ValueError(f"Only 4-bit quantized models can be converted to Marlin format. You passed a model with bits={quantization_config.bits}")
    if quantization_config.group_size != 128:
        raise ValueError(f"Only group size 128 models can be converted to Marlin format. You passed a model with group_size={quantization_config.group_size}")
    if not quantization_config.sym:
        raise ValueError(f"Only models with symmetric quantization can be converted to Marlin format. You passed a model with sym={quantization_config.sym}")
    if quantization_config.desc_act:
        raise ValueError(f"Models with act-order quantization cannot be converted to Marlin format. You passed a model with desc_act={quantization_config.desc_act}")

@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
    # Unpack the 4-bit values (packed eight per int32) into int8 tensors.
    unpacked_weights = torch.zeros(
        (qweight.shape[0] * 8, qweight.shape[1]),
        dtype=torch.int8, device=qweight.device, requires_grad=False)
    unpacked_zeros = torch.zeros(
        (qzeros.shape[0], qzeros.shape[1] * 8),
        dtype=torch.int8, device=qzeros.device, requires_grad=False)
    # qweight packs along rows; qzeros packs along columns.
    for row in range(unpacked_weights.shape[0]):
        i = row % 8
        unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
    for col in range(unpacked_zeros.shape[1]):
        i = col % 8
        unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
    # AutoGPTQ stores the zero points minus one, so add the one back here.
    return unpacked_weights, unpacked_zeros + 1
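
# A worked example of the nibble extraction above: eight 4-bit values are
# packed little-end-first into each int32, so for a packed word 0x76543210,
# (word >> (4 * i)) & 0xF yields value i for i in 0..7. Masking after the
# shift also keeps the arithmetic right-shift of negative int32 words correct.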

@torch.no_grad()
def dequantize_weight(layer):
    qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
    unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
    group_size = unpacked_qweight.shape[0] // scales.shape[0]
    # Expand the per-group scales and zero points to one row per weight row.
    scales = scales.repeat_interleave(group_size, dim=0)
    unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
    unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
    return unpacked_qweight.T
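
# In formula form, with g = row // group_size indexing the quantization group
# and z the zero point restored in unpack_4bit_to_32bit_signed, the function
# above computes
#
#   W[row, col] = (q[row, col] - z[g, col]) * s[g, col]
#
# and returns W.T so the result matches nn.Linear's (out_features, in_features)
# weight layout.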

@torch.no_grad()
def convert_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue
        if verbose:
            print(f"--- Converting Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module).to(torch.float16)
        linear_module = torch.nn.Linear(
            in_features=dequantized_weight.shape[1],
            out_features=dequantized_weight.shape[0],
            bias=False,
            dtype=torch.float16,
            device="cuda")
        linear_module.weight.data.copy_(dequantized_weight)

        # Create new linear method and copy to model.
        new_module = MarlinLayer(
            infeatures=linear_module.in_features,
            outfeatures=linear_module.out_features,
            groupsize=model.config.quantization_config.group_size)
        new_module.pack(linear_module, scales=copy.deepcopy(module.scales.data.t()))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, module
        torch.cuda.empty_cache()
        gc.collect()
    return model
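
# Note: per the marlin reference implementation, MarlinLayer.pack() takes the
# fp16 Linear plus the original per-group scales (transposed above to the
# layout pack() expects) and requantizes the weights into Marlin's tiled
# 4-bit format.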

@torch.no_grad()
def dequantize_model(model, verbose=True):
    for name, module in model.named_modules():
        if not isinstance(module, QuantLinear):
            continue
        if verbose:
            print(f"--- Dequantizing Module: {name}")
        parent_name = ".".join(name.split(".")[:-1])
        layer_name = name[len(parent_name) + 1:]

        # Dequantize the weight.
        dequantized_weight = dequantize_weight(module)
        dequantized_weight_cpu = dequantized_weight.to("cpu")

        # Create new linear method and copy to model.
        new_module = torch.nn.Linear(
            in_features=dequantized_weight_cpu.shape[1],
            out_features=dequantized_weight_cpu.shape[0],
            bias=False,
            dtype=torch.float16)
        new_module.weight.data.copy_(dequantized_weight_cpu)
        new_module.scales = torch.nn.Parameter(copy.deepcopy(module.scales.data))

        # Save to parent.
        parent_module = model.get_submodule(parent_name)
        setattr(parent_module, layer_name, new_module)

        # Free cuda memory.
        del dequantized_weight, dequantized_weight_cpu, module
        torch.cuda.empty_cache()
    return model
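
# Note: dequantize_model() is not used by the __main__ path below; it is a
# standalone helper that swaps each QuantLinear for a plain fp16 Linear while
# keeping the original per-group scales attached, presumably for later
# requantization.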
if __name__ == "__main__":
args = parser.parse_args()
model_id = args.model_id
save_path = args.save_path
do_generation = args.do_generation
print("Loading gptq model...")
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Validate that this model is compatible with Marlin.
print("Validating compatibility...")
_validate_compatibility(model)
# Dequantize the Model.
print("Converting model...")
model = convert_model(model).to("cpu")
# Save after updating quantization config.
print("Saving marlin model...")
model.config.quantization_config = {
"group_size": model.config.quantization_config.group_size,
"quant_method": "marlin"
}
model.save_pretrained(save_path)
tokenizer.save_pretrained(save_path)
if do_generation:
print("Generating sample text...")
model.to("cuda")
prompt = "My favorite song is"
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to("cuda") for k, v in inputs.items()}
outputs = model.generate(**inputs, max_new_tokens=50, do_sample=False)
print(tokenizer.batch_decode(outputs)[0])
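
# Example usage (assuming this script is saved as convert.py; the model id is
# illustrative):
#
#   python convert.py \
#       --model-id TheBloke/Llama-2-7B-GPTQ \
#       --save-path ./llama-2-7b-marlin \
#       --do-generation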