|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig |
|
from peft import PeftModel, PeftConfig |
|
from model import Model |
|
|
|
class KoAlpaca(Model):
    """Chat wrapper around a 4-bit quantized causal LM with a PEFT adapter.

    The adapter ``4n3mone/Komuchat-koalpaca-polyglot-12.8B`` is applied on top
    of the base model named in the adapter's own PEFT config.

    Attributes set by ``__init__``:
        bnb_config: the ``BitsAndBytesConfig`` used to quantize the base model.
        model: the ``PeftModel`` (base model + adapter), switched to eval mode.
        tokenizer: tokenizer loaded from the base model id.
        gen_config: ``GenerationConfig`` read from
            ``./models/koalpaca/gen_config.json``.
        INPUT_FORMAT: prompt template; ``<INPUT>`` is replaced by the user text.
    """

    def __init__(self):
        """Load the quantized base model, the adapter, tokenizer and generation config."""
        peft_model_id = "4n3mone/Komuchat-koalpaca-polyglot-12.8B"
        config = PeftConfig.from_pretrained(peft_model_id)
        base_model_id = config.base_model_name_or_path

        self.bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )

        # device_map='auto' lets the loader place the weights; the device the
        # inputs must live on is read back from the model in generate().
        base_model = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            quantization_config=self.bnb_config,
            device_map='auto',
        )
        self.model = PeftModel.from_pretrained(base_model, peft_model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_id)
        self.gen_config = GenerationConfig.from_pretrained('./models/koalpaca', 'gen_config.json')
        self.INPUT_FORMAT = "### 질문: <INPUT>\n\n### 답변:"
        self.model.eval()

    def generate(self, inputs: str) -> str:
        """Generate a reply to ``inputs``.

        Args:
            inputs: the user's message; substituted for ``<INPUT>`` in
                ``INPUT_FORMAT``.

        Returns:
            The decoded text that follows the last ``### 답변:`` marker, with
            leading whitespace removed.
        """
        prompt = self.INPUT_FORMAT.replace('<INPUT>', inputs)

        # Use the model's own device instead of a separate Accelerator object:
        # the original referenced an `accelerator` name that was neither
        # imported nor visible in this method.
        encoded = self.tokenizer(
            prompt,
            return_tensors='pt',
            return_token_type_ids=False,
        ).to(self.model.device)

        output_ids = self.model.generate(
            **encoded,
            generation_config=self.gen_config,
        )
        decoded = self.tokenizer.decode(output_ids[0])

        # The prompt ends with "### 답변:" (no trailing space), so split on the
        # marker exactly as it appears in the prompt and strip the separator
        # whitespace afterwards; splitting on "### 답변: " silently returned the
        # whole prompt whenever the reply did not start with a space.
        return decoded.split("### 답변:")[-1].lstrip()