gardarjuto's picture
fix max_length
6ce586a
raw
history blame
5.73 kB
from typing import Dict, List, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import logging
logging.basicConfig(level=logging.INFO)
LOGGER = logging.getLogger(__name__)
# Prompts for the different tasks
START_PROMPT_TASK1 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
END_PROMPT_TASK1 = "Sérðu eitthvað sem mætti betur fara í textanum? Búðu til lista af öllum slíkum tilvikum þar sem hver lína tilgreinir hver villan er, hvar hún er, og hvað væri gert í staðinn fyrir villuna.\n\n"
START_PROMPT_TASK2 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.Ég er með tvær útgáfur af textanum, A og B, og önnur þeirra gæti verið betri en hin á einhvern hátt, t.d. hvað varðar stafsetningu, málfræði o.s.frv.\nHér er texti A:\n\n"
MIDDLE_PROMPT_TASK2 = "Hér er texti B:\n\n"
END_PROMPT_TASK2 = "Hvorn textann líst þér betur á?\n\n"
START_PROMPT_TASK3 = "Hér er texti sem ég vil að þú skoðir vel og vandlega. Þú skalt skoða hvert einasta orð, orðasamband, og setningu og meta hvort þér finnist eitthvað athugavert, til dæmis hvað varðar málfræði, stafsetningu, skringilega merkingu og svo framvegis.\nHér er textinn:\n\n"
END_PROMPT_TASK3 = "Reyndu nú að laga textann þannig að hann líti betur út, eins og þér finnst best við hæfi.\n\n"
START_PROMPT_TASK = {
1: START_PROMPT_TASK1,
2: START_PROMPT_TASK2,
3: START_PROMPT_TASK3,
}
END_PROMPT_TASK = {1: END_PROMPT_TASK1, 2: END_PROMPT_TASK2, 3: END_PROMPT_TASK3}
SEP = "\n\n"
class EndpointHandler:
def __init__(self, path=""):
self.model = AutoModelForCausalLM.from_pretrained(
path, device_map="auto", torch_dtype=torch.bfloat16
)
LOGGER.info(f"Inference model loaded from {path}")
LOGGER.info(f"Model device: {self.model.device}")
# Fix the pad and bos tokens to avoid bug in the tokenizer
pad_token = "<unk>"
bos_token = "<|endoftext|>"
self.tokenizer = AutoTokenizer.from_pretrained(
"AI-Sweden-Models/gpt-sw3-6.7b", pad_token=pad_token, bos_token=bos_token
)
def check_valid_inputs(
self, input_a: str, input_b: str, task: int
) -> bool:
"""
Check if the inputs are valid
"""
if task not in [1, 2, 3]:
return False
if task == 1 or task == 3:
if input_a is None:
return False
elif task == 2:
if input_a is None or input_b is None:
return False
return True
def tokenize_input(self, input_a: str, input_b: str, task: int) -> List[int]:
"""
Tokenize the input
"""
if task == 1 or task == 3:
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
tokenized_sentence = self.tokenizer(input_a + SEP)["input_ids"]
concatted_data = (
[self.tokenizer.bos_token_id]
+ tokenized_start
+ tokenized_sentence
+ tokenized_end
)
elif task == 2:
tokenized_start = self.tokenizer(START_PROMPT_TASK[task])["input_ids"]
tokenized_middle = self.tokenizer(MIDDLE_PROMPT_TASK2)["input_ids"]
tokenized_end = self.tokenizer(END_PROMPT_TASK[task])["input_ids"]
tokenized_sentence_a = self.tokenizer(input_a + SEP)["input_ids"]
tokenized_sentence_b = self.tokenizer(input_b + SEP)["input_ids"]
concatted_data = (
[self.tokenizer.bos_token_id]
+ tokenized_start
+ tokenized_sentence_a
+ tokenized_middle
+ tokenized_sentence_b
+ tokenized_end
)
return concatted_data
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
data args:
inputs (:obj: `str` | `PIL.Image` | `np.array`)
kwargs
Return:
A :obj:`list` | `dict`: will be serialized and returned
"""
LOGGER.info(f"Received data: {data}")
# Get inputs
input_a = data.pop("input_a", None)
input_b = data.pop("input_b", None)
task = data.pop("task", None)
parameters = data.pop("parameters", {})
# Check valid inputs
if not self.check_valid_inputs(input_a, input_b, task):
return [{"error": "Invalid inputs"}]
if "max_new_tokens" not in parameters and "max_length" not in parameters:
parameters["max_new_tokens"] = 512
# Tokenize the input
tokenized_input = self.tokenize_input(input_a, input_b, task)
# Move the input to the device
input_ids = torch.tensor(tokenized_input).to(self.model.device)
input_ids = input_ids.unsqueeze(0)
# Generate the output
output = self.model.generate(input_ids, **parameters)
# Decode only the new part of the output
decoded_output = self.tokenizer.decode(
output[0][len(tokenized_input) :], skip_special_tokens=True
).strip()
return [{"output": decoded_output}]