Jeff28 committed on
Commit
7686e2d
1 Parent(s): 1a2d1dc

Update app.py

Files changed (1)
  1. app.py +110 -144
app.py CHANGED
@@ -1,172 +1,138 @@
- import spaces
  import os
- import json
- import subprocess
- from llama_cpp import Llama
- from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
- from llama_cpp_agent.providers import LlamaCppPythonProvider
- from llama_cpp_agent.chat_history import BasicChatHistory
- from llama_cpp_agent.chat_history.messages import Roles
- import gradio as gr
- from huggingface_hub import hf_hub_download

- huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
-
- hf_hub_download(
-     repo_id="bartowski/gemma-2-9b-it-GGUF",
-     filename="gemma-2-9b-it-Q5_K_M.gguf",
-     local_dir="./models"
- )

- hf_hub_download(
-     repo_id="bartowski/gemma-2-27b-it-GGUF",
-     filename="gemma-2-27b-it-Q5_K_M.gguf",
-     local_dir="./models"
- )

- hf_hub_download(
-     repo_id="google/gemma-2-2b-it-GGUF",
-     filename="2b_it_v2.gguf",
-     local_dir="./models",
-     token=huggingface_token
  )
-
-
-
- llm = None
- llm_model = None
-
- @spaces.GPU(duration=120)
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     model,
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
-     top_k,
-     repeat_penalty,
- ):
-     chat_template = MessagesFormatterType.GEMMA_2
-
-     global llm
-     global llm_model
-
-     if llm is None or llm_model != model:
-         llm = Llama(
-             model_path=f"models/{model}",
-             flash_attn=True,
-             n_gpu_layers=81,
-             n_batch=1024,
-             n_ctx=8192,
          )
-         llm_model = model
-
-     provider = LlamaCppPythonProvider(llm)
-
-     agent = LlamaCppAgent(
-         provider,
-         system_prompt=f"{system_message}",
-         predefined_messages_formatter_type=chat_template,
-         debug_output=True
      )
-
-     settings = provider.get_provider_default_settings()
-     settings.temperature = temperature
-     settings.top_k = top_k
-     settings.top_p = top_p
-     settings.max_tokens = max_tokens
-     settings.repeat_penalty = repeat_penalty
-     settings.stream = True
-
-     messages = BasicChatHistory()

-     for msn in history:
-         user = {
-             'role': Roles.user,
-             'content': msn[0]
-         }
-         assistant = {
-             'role': Roles.assistant,
-             'content': msn[1]
-         }
-         messages.add_message(user)
-         messages.add_message(assistant)
-
-     stream = agent.get_chat_response(
-         message,
-         llm_sampling_settings=settings,
-         chat_history=messages,
-         returns_streaming_generator=True,
-         print_output=False
-     )
-
-     outputs = ""
-     for output in stream:
-         outputs += output
-         yield outputs

- description = """<p align="center">Defaults to 2B (you can switch to 9B or 27B from additional inputs)</p>
- <p><center>
- <a href="https://huggingface.co/google/gemma-2-27b-it" target="_blank">[27B it Model]</a>
- <a href="https://huggingface.co/google/gemma-2-9b-it" target="_blank">[9B it Model]</a>
- <a href="https://huggingface.co/google/gemma-2-2b-it" target="_blank">[2B it Model]</a>
- <a href="https://huggingface.co/bartowski/gemma-2-27b-it-GGUF" target="_blank">[27B it Model GGUF]</a>
- <a href="https://huggingface.co/bartowski/gemma-2-9b-it-GGUF" target="_blank">[9B it Model GGUF]</a>
- <a href="https://huggingface.co/google/gemma-2-2b-it-GGUF" target="_blank">[2B it Model GGUF]</a>
- </center></p>
- """

- demo = gr.ChatInterface(
-     respond,
      additional_inputs=[
-         gr.Dropdown([
-                 'gemma-2-9b-it-Q5_K_M.gguf',
-                 'gemma-2-27b-it-Q5_K_M.gguf',
-                 '2b_it_v2.gguf'
-             ],
-             value="2b_it_v2.gguf",
-             label="Model"
          ),
-         gr.Textbox(value="You are a helpful assistant.", label="System message"),
-         gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
          gr.Slider(
              minimum=0.1,
              maximum=1.0,
-             value=0.95,
              step=0.05,
-             label="Top-p",
          ),
          gr.Slider(
-             minimum=0,
-             maximum=100,
-             value=40,
-             step=1,
              label="Top-k",
          ),
          gr.Slider(
-             minimum=0.0,
-             maximum=2.0,
-             value=1.1,
-             step=0.1,
              label="Repetition penalty",
          ),
      ],
-     retry_btn="Retry",
-     undo_btn="Undo",
-     clear_btn="Clear",
-     submit_btn="Send",
-     title="Chat with Gemma 2 using llama.cpp",
-     description=description,
-     chatbot=gr.Chatbot(
-         scale=1,
-         likeable=False,
-         show_copy_button=True
-     )
  )

  if __name__ == "__main__":
-     demo.launch()
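For context: the removed implementation above lazily (re)loads a quantized GGUF model through llama-cpp-python, caching it in module globals so the weights are reloaded only when the dropdown selection changes. A minimal sketch of that load-and-cache pattern, assuming llama-cpp-python is installed and the GGUF file already exists under models/ (n_ctx mirrors the value above but is otherwise illustrative):

# Sketch of the lazy load-and-cache pattern from the removed code.
# Assumes llama-cpp-python is installed and the GGUF file exists;
# n_ctx=8192 mirrors the value above but is otherwise illustrative.
from llama_cpp import Llama

_llm = None        # cached model instance
_llm_file = None   # filename the cache was loaded from

def get_llm(model_file: str) -> Llama:
    """Reload the GGUF weights only when a different model is requested."""
    global _llm, _llm_file
    if _llm is None or _llm_file != model_file:
        _llm = Llama(model_path=f"models/{model_file}", n_ctx=8192)
        _llm_file = model_file
    return _llm

The replacement below drops llama.cpp entirely and serves google/gemma-2-2b-it directly with transformers.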
 
 
  import os
+ from threading import Thread
+ from typing import Iterator

+ import gradio as gr
+ import spaces
+ import torch
+ from transformers import AutoModelForCausalLM, GemmaTokenizerFast, TextIteratorStreamer
+
+ DESCRIPTION = """\
+ # Gemma 2 2B IT
+ Gemma 2 is Google's latest iteration of open LLMs.
+ This is a demo of [`google/gemma-2-2b-it`](https://huggingface.co/google/gemma-2-2b-it), fine-tuned for instruction following.
+ For more details, please check [our post](https://huggingface.co/blog/gemma2).
+ 👉 Looking for a larger and more powerful version? Try the 27B version in [HuggingChat](https://huggingface.co/chat/models/google/gemma-2-27b-it) and the 9B version in [this Space](https://huggingface.co/spaces/huggingface-projects/gemma-2-9b-it).
+ """

+ MAX_MAX_NEW_TOKENS = 2048
+ DEFAULT_MAX_NEW_TOKENS = 1024
+ MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))

+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+ model_id = "google/gemma-2-2b-it"
+ tokenizer = GemmaTokenizerFast.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(
+     model_id,
+     device_map="auto",
+     torch_dtype=torch.bfloat16,
  )
+ model.config.sliding_window = 4096
+ model.eval()
+
+
+ @spaces.GPU(duration=90)
+ def generate(
+     message: str,
+     chat_history: list[tuple[str, str]],
+     max_new_tokens: int = 1024,
+     temperature: float = 0.6,
+     top_p: float = 0.9,
+     top_k: int = 50,
+     repetition_penalty: float = 1.2,
+ ) -> Iterator[str]:
+     conversation = []
+     for user, assistant in chat_history:
+         conversation.extend(
+             [
+                 {"role": "user", "content": user},
+                 {"role": "assistant", "content": assistant},
+             ]
          )
+     conversation.append({"role": "user", "content": message})
+
+     input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
+     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+     input_ids = input_ids.to(model.device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         {"input_ids": input_ids},
+         streamer=streamer,
+         max_new_tokens=max_new_tokens,
+         do_sample=True,
+         top_p=top_p,
+         top_k=top_k,
+         temperature=temperature,
+         num_beams=1,
+         repetition_penalty=repetition_penalty,
      )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()

+     outputs = []
+     for text in streamer:
+         outputs.append(text)
+         yield "".join(outputs)


+ chat_interface = gr.ChatInterface(
+     fn=generate,
      additional_inputs=[
+         gr.Slider(
+             label="Max new tokens",
+             minimum=1,
+             maximum=MAX_MAX_NEW_TOKENS,
+             step=1,
+             value=DEFAULT_MAX_NEW_TOKENS,
          ),
          gr.Slider(
+             label="Temperature",
              minimum=0.1,
+             maximum=4.0,
+             step=0.1,
+             value=0.6,
+         ),
+         gr.Slider(
+             label="Top-p (nucleus sampling)",
+             minimum=0.05,
              maximum=1.0,
              step=0.05,
+             value=0.9,
          ),
          gr.Slider(
              label="Top-k",
+             minimum=1,
+             maximum=1000,
+             step=1,
+             value=50,
          ),
          gr.Slider(
              label="Repetition penalty",
+             minimum=1.0,
+             maximum=2.0,
+             step=0.05,
+             value=1.2,
          ),
      ],
+     stop_btn=None,
+     examples=[
+         ["Hello there! How are you doing?"],
+         ["Can you explain briefly to me what is the Python programming language?"],
+         ["Explain the plot of Cinderella in a sentence."],
+         ["How many hours does it take a man to eat a Helicopter?"],
+         ["Write a 100-word article on 'Benefits of Open-Source in AI research'"],
+     ],
+     cache_examples=False,
  )

+ with gr.Blocks(css="style.css", fill_height=True) as demo:
+     gr.Markdown(DESCRIPTION)
+     gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
+     chat_interface.render()
+
  if __name__ == "__main__":
+     demo.queue(max_size=20).launch()
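The new implementation streams tokens by running model.generate on a background thread and draining a TextIteratorStreamer, yielding the accumulated text so the Gradio chat updates incrementally. A self-contained sketch of that pattern, assuming a tiny stand-in checkpoint (sshleifer/tiny-gpt2 is only a placeholder so the sketch runs quickly on CPU):

# Sketch of the thread + TextIteratorStreamer streaming pattern used above.
# "sshleifer/tiny-gpt2" is a stand-in assumption, not the model in the app.
from threading import Thread
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")

inputs = tokenizer("Hello there!", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while we consume tokens here.
Thread(target=model.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=32)).start()

outputs = []
for text in streamer:        # yields decoded chunks as they are produced
    outputs.append(text)
    print("".join(outputs))  # the app yields this growing string to Gradio

Yielding the full accumulated string on each chunk, rather than the chunk alone, matches what gr.ChatInterface expects from a streaming generator.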