import os
import time
import typing as t

# hf_transfer speeds up model weight downloads from the Hugging Face Hub.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
from vllm import LLM, SamplingParams
from cog import BasePredictor, Input, ConcatenateIterator


MODEL_ID = "TheBloke/Mistral-7B-OpenOrca-AWQ"

PROMPT_TEMPLATE = """\
<|im_start|>system
You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!
<|im_end|>
<|im_start|>user
{prompt}<|im_end|>
<|im_start|>assistant
"""
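# The template above follows the ChatML format that the OpenOrca fine-tune expects;
# `{prompt}` is substituted by `predict()` via `str.format`.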

DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_TEMPERATURE = 0.8
DEFAULT_TOP_P = 0.95
DEFAULT_TOP_K = 50
DEFAULT_PRESENCE_PENALTY = 0.0
DEFAULT_FREQUENCY_PENALTY = 0.0


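# Streaming helper: rather than calling the high-level `LLM.generate`, this function
# drives the underlying engine directly (`add_request` + `step`) so partial text can
# be yielded to the caller after every decoding step.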
def vllm_generate_iterator(
    self,
    prompt: str,
    /,
    *,
    echo: bool = False,
    stop: t.Optional[t.Union[str, t.List[str]]] = None,
    stop_token_ids: t.Optional[t.List[int]] = None,
    sampling_params: t.Optional[SamplingParams] = None,
    **attrs: t.Any,
) -> t.Iterator[t.Dict[str, t.Any]]:
    request_id: t.Optional[str] = attrs.pop("request_id", None)
    if request_id is None:
        raise ValueError("request_id must not be None.")

    # Collect stop strings from `stop` and from decoded stop-token ids. Note that the
    # collected set is not applied to `sampling_params` here; stopping relies on the
    # sampling params supplied by the caller and on the EOS token.
    if stop_token_ids is None:
        stop_token_ids = []
    stop_token_ids.append(self.tokenizer.eos_token_id)
    stop_ = set()
    if isinstance(stop, str) and stop != "":
        stop_.add(stop)
    elif isinstance(stop, list) and stop != []:
        stop_.update(stop)
    for tid in stop_token_ids:
        if tid:
            stop_.add(self.tokenizer.decode(tid))

    # Queue the request on the engine, then step it manually so partial outputs can be
    # yielded as soon as each decoding step completes.
    self.add_request(request_id=request_id, prompt=prompt, sampling_params=sampling_params)

    while self.has_unfinished_requests():
        for request_output in self.step():
            for output in request_output.outputs:
                yield {
                    "text": output.text,
                    "error_code": 0,
                    "num_tokens": len(output.token_ids),
                }
            if request_output.finished:
                break


class Predictor(BasePredictor):
    def setup(self):
        # Load the AWQ-quantized weights once, at container startup.
        self.llm = LLM(
            model=MODEL_ID,
            quantization="awq",
            dtype="float16",
        )

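    # predict() is declared to return a ConcatenateIterator, which tells Cog to stream
    # each yielded chunk to the client instead of buffering the full output.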
    def predict(
        self,
        prompt: str,
        max_new_tokens: int = Input(
            description="The maximum number of tokens the model should generate as output.",
            default=DEFAULT_MAX_NEW_TOKENS,
        ),
        temperature: float = Input(
            description="The value used to modulate the next token probabilities.",
            default=DEFAULT_TEMPERATURE,
        ),
        top_p: float = Input(
            description="A probability threshold for generating the output. If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).",
            default=DEFAULT_TOP_P,
        ),
        top_k: int = Input(
            description="The number of highest probability tokens to consider for generating the output. If > 0, only keep the top k tokens with highest probability (top-k filtering).",
            default=DEFAULT_TOP_K,
        ),
        presence_penalty: float = Input(
            description="Presence penalty: penalizes new tokens based on whether they already appear in the generated text.",
            default=DEFAULT_PRESENCE_PENALTY,
        ),
        frequency_penalty: float = Input(
            description="Frequency penalty: penalizes new tokens based on how often they already appear in the generated text.",
            default=DEFAULT_FREQUENCY_PENALTY,
        ),
        prompt_template: str = Input(
            description="The template used to format the prompt. The input prompt is inserted into the template using the `{prompt}` placeholder.",
            default=PROMPT_TEMPLATE,
        ),
    ) -> ConcatenateIterator:
        prompts = [
            (
                prompt_template.format(prompt=prompt),
                SamplingParams(
                    max_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                    top_p=top_p,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                ),
            )
        ]
        start = time.time()
        while True:
            if prompts:
                prompt, sampling_params = prompts.pop(0)
                gen = vllm_generate_iterator(
                    self.llm.llm_engine,
                    prompt,
                    echo=False,
                    stop=None,
                    stop_token_ids=None,
                    sampling_params=sampling_params,
                    request_id="0",
                )
                # The engine yields the full text generated so far at every step, so
                # emit only the newly generated suffix each time.
                last = ""
                num_tokens = 0
                for x in gen:
                    if x["text"] == "":
                        continue
                    yield x["text"][len(last):]
                    last = x["text"]
                    num_tokens = x["num_tokens"]
                print(f"\nGenerated {num_tokens} tokens in {time.time() - start:.2f} seconds.")

            if not (self.llm.llm_engine.has_unfinished_requests() or prompts):
                break


if __name__ == "__main__":
    # Simple smoke test: run a single prediction and stream it to stdout.
    import sys

    p = Predictor()
    p.setup()
    gen = p.predict(
        prompt="Write me an itinerary for my dog's birthday party.",
        max_new_tokens=512,
        temperature=0.8,
        top_p=0.95,
        top_k=50,
        presence_penalty=1.0,
        frequency_penalty=0.2,
        prompt_template=PROMPT_TEMPLATE,
    )
    for out in gen:
        print(out, end="")
        sys.stdout.flush()
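
# Note: when deployed with Cog/Replicate, `setup()` runs once at container start and
# `predict()` runs per request. Assuming a cog.yaml whose `predict` entry points at
# `predict.py:Predictor`, the same prediction can be run with, e.g.:
#   cog predict -i prompt="Write me an itinerary for my dog's birthday party."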