Arivmta19 committed on
Commit 5bfaa4d
1 Parent(s): 3d513f6

Create predict.py

Files changed (1)
  1. predict.py +148 -0
predict.py ADDED
@@ -0,0 +1,148 @@
+ import os
+ import time
+
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # use hf_transfer for faster weight downloads
+
+ from vllm import LLM, SamplingParams
+ import torch
+ from cog import BasePredictor, Input, ConcatenateIterator
+ import typing as t
+
+
+ MODEL_ID = "TheBloke/Mistral-7B-OpenOrca-AWQ"
+ PROMPT_TEMPLATE = """\
+ <|im_start|>system
+ You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!
+ <|im_end|>
+ <|im_start|>user
+ {prompt}<|im_end|>
+ <|im_start|>assistant
+ """
+
+ DEFAULT_MAX_NEW_TOKENS = 512
+ DEFAULT_TEMPERATURE = 0.8
+ DEFAULT_TOP_P = 0.95
+ DEFAULT_TOP_K = 50
+ DEFAULT_PRESENCE_PENALTY = 0.0  # 1.15
+ DEFAULT_FREQUENCY_PENALTY = 0.0  # 0.2
+
+
+ def vllm_generate_iterator(
+     self, prompt: str, /, *, echo: bool = False, stop: t.Optional[str] = None, stop_token_ids: t.Optional[t.List[int]] = None, sampling_params=None, **attrs: t.Any
+ ) -> t.Iterator[t.Dict[str, t.Any]]:
+     request_id: str = attrs.pop('request_id', None)
+     if request_id is None: raise ValueError('request_id must not be None.')
+     if stop_token_ids is None: stop_token_ids = []
+     stop_token_ids.append(self.tokenizer.eos_token_id)  # always stop on EOS
+     stop_ = set()
+     if isinstance(stop, str) and stop != '': stop_.add(stop)
+     elif isinstance(stop, list) and stop != []: stop_.update(stop)
+     for tid in stop_token_ids:
+         if tid: stop_.add(self.tokenizer.decode(tid))
+
+     # if self.config['temperature'] <= 1e-5: top_p = 1.0
+     # else: top_p = self.config['top_p']
+     # config = self.config.model_construct_env(stop=list(stop_), top_p=top_p, **attrs)
+     self.add_request(request_id=request_id, prompt=prompt, sampling_params=sampling_params)  # queue the prompt on the engine
+
+     token_cache = []
+     print_len = 0
+
+     while self.has_unfinished_requests():  # step the engine until the request is finished
+         for request_output in self.step():
+             # Yield the cumulative text generated so far
+             for output in request_output.outputs:
+                 text = output.text
+                 yield {'text': text, 'error_code': 0, 'num_tokens': len(output.token_ids)}
+
+             if request_output.finished: break
+
+
+ class Predictor(BasePredictor):
+
+     def setup(self):
+         self.llm = LLM(
+             model=MODEL_ID,
+             quantization="awq",  # the model weights are AWQ-quantized
+             dtype="float16"
+         )
+
+     def predict(
+         self,
+         prompt: str,
+         max_new_tokens: int = Input(
+             description="The maximum number of tokens the model should generate as output.",
+             default=DEFAULT_MAX_NEW_TOKENS,
+         ),
+         temperature: float = Input(
+             description="The value used to modulate the next token probabilities.", default=DEFAULT_TEMPERATURE
+         ),
+         top_p: float = Input(
+             description="A probability threshold for generating the output. If < 1.0, only keep the top tokens with cumulative probability >= top_p (nucleus filtering). Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).",
+             default=DEFAULT_TOP_P,
+         ),
+         top_k: int = Input(
+             description="The number of highest probability tokens to consider for generating the output. If > 0, only keep the top k tokens with highest probability (top-k filtering).",
+             default=DEFAULT_TOP_K,
+         ),
+         presence_penalty: float = Input(
+             description="Presence penalty",
+             default=DEFAULT_PRESENCE_PENALTY,
+         ),
+         frequency_penalty: float = Input(
+             description="Frequency penalty",
+             default=DEFAULT_FREQUENCY_PENALTY,
+         ),
+         prompt_template: str = Input(
+             description="The template used to format the prompt. The input prompt is inserted into the template using the `{prompt}` placeholder.",
+             default=PROMPT_TEMPLATE,
+         )
+     ) -> ConcatenateIterator:
+         prompts = [
+             (
+                 prompt_template.format(prompt=prompt),
+                 SamplingParams(
+                     max_tokens=max_new_tokens,
+                     temperature=temperature,
+                     top_k=top_k,
+                     top_p=top_p,
+                     presence_penalty=presence_penalty,
+                     frequency_penalty=frequency_penalty
+                 )
+             )
+         ]
+         start = time.time()
+         while True:
+             if prompts:
+                 prompt, sampling_params = prompts.pop(0)
+                 gen = vllm_generate_iterator(self.llm.llm_engine, prompt, echo=False, stop=None, stop_token_ids=None, sampling_params=sampling_params, request_id="0")
+                 last, num_tokens = "", 0
+                 for x in gen:
+                     if x['text'] == "":
+                         continue
+                     yield x['text'][len(last):]  # emit only the newly generated suffix
+                     last = x["text"]
+                     num_tokens = x["num_tokens"]
+                 print(f"\nGenerated {num_tokens} tokens in {time.time() - start} seconds.")
+
+             if not (self.llm.llm_engine.has_unfinished_requests() or prompts):
+                 break
+
+
+ if __name__ == '__main__':
+     import sys
+     p = Predictor()
+     p.setup()
+     gen = p.predict(
+         "Write me an itinerary for my dog's birthday party.",
+         512,
+         0.8,
+         0.95,
+         50,
+         1.0,
+         0.2,
+         PROMPT_TEMPLATE,
+     )
+     for out in gen:
+         print(out, end="")
+         sys.stdout.flush()
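
Besides the built-in `__main__` smoke test above, the Predictor can be exercised through Cog itself. The following cog.yaml is only a sketch and is not part of this commit: it assumes a CUDA GPU, and the package list is unpinned and illustrative, so adjust it to the versions actually used with this model.

    # cog.yaml (illustrative, not included in this commit)
    build:
      gpu: true
      python_version: "3.10"
      python_packages:
        - "torch"
        - "vllm"
        - "hf_transfer"
    predict: "predict.py:Predictor"

With a cog.yaml like that next to predict.py, a prediction can be run from the command line, e.g.:

    cog predict -i prompt="Write me an itinerary for my dog's birthday party." -i max_new_tokens=512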