New version using endpoints of Falcon and Mistral
app.py CHANGED
@@ -2,10 +2,30 @@ import random
 import time
 import torch
 import gradio as gr
+import requests
 from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
 from peft import PeftModel, PeftConfig
 from textwrap import wrap, fill
 
+## using Falcon 7b Instruct
+Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
+HEADERS = {"Authorization": "Bearer <HF TOKEN>"}
+def falcon_query(payload):
+    response = requests.post(Falcon_API_URL, headers=HEADERS, json=payload)
+    return response.json()
+def falcon_inference(input_text):
+    payload = {"inputs": input_text}
+    return falcon_query(payload)
+
+## using Mistral
+Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
+def mistral_query(payload):
+    response = requests.post(Mistral_API_URL, headers=HEADERS, json=payload)
+    return response.json()
+def mistral_inference(input_text):
+    payload = {"inputs": input_text}
+    return mistral_query(payload)
+
 # Functions to Wrap the Prompt Correctly
 def wrap_text(text, width=90):
     lines = text.split('\n')
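For context on these helpers: the hosted inference endpoints accept a JSON payload of the form {"inputs": ...} and, for text generation, reply with a list like [{"generated_text": ...}], or with an error object while the model is still loading. A minimal, more defensive sketch of the same call (the name query_endpoint and the timeout value are illustrative, not part of this commit; in practice the bearer token would also come from an environment variable rather than being hardcoded) could look like:

import os
import requests

def query_endpoint(api_url, text, timeout=30):
    # Hypothetical helper: POST the prompt and unwrap the response.
    # Success looks like [{"generated_text": ...}]; a dict with an
    # "error" key usually means the model is still loading.
    headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
    response = requests.post(api_url, headers=headers,
                             json={"inputs": text}, timeout=timeout)
    response.raise_for_status()
    data = response.json()
    if isinstance(data, dict) and "error" in data:
        raise RuntimeError(data["error"])
    return data[0]["generated_text"]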
@@ -65,8 +85,14 @@ class ChatbotInterface():
 
         self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])
 
+    def respond(self, msg, chatbot):
+        raise NotImplementedError
+
+class GaiaMinimed(ChatbotInterface):
+    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
+        super().__init__(name, system_prompt)
+
     def respond(self, msg, history):
-        #bot_message = random.choice(["Hello, I'm MedChat! How can I help you?", "Hello there! I'm Medchat, a medical assistant! How can I help you?"])
         formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
         input_ids = tokenizer.encode(
             formatted_input,
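The base class now stubs respond with NotImplementedError so that each tab's bot must supply its own implementation. A minimal sketch of the same dispatch pattern in isolation (class names here are hypothetical), using abc to make the contract explicit:

from abc import ABC, abstractmethod

class ChatBase(ABC):
    @abstractmethod
    def respond(self, msg, history):
        """Return ("", updated_history) for the Gradio callback."""

class EchoBot(ChatBase):
    def respond(self, msg, history):
        # Trivial override for illustration: echo the user's message.
        history.append([msg, msg])
        return "", history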
@@ -90,6 +116,26 @@ class ChatbotInterface():
 
         return "", self.chat_history
 
+class FalconBot(ChatbotInterface):
+    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
+        super().__init__(name, system_prompt)
+
+    def respond(self, msg, chatbot):
+        falcon_response = falcon_inference(msg)
+        falcon_output = falcon_response[0]["generated_text"]
+        self.chat_history.append([msg, falcon_output])
+        return "", falcon_output
+
+class MistralBot(ChatbotInterface):
+    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
+        super().__init__(name, system_prompt)
+
+    def respond(self, msg, chatbot):
+        mistral_response = mistral_inference(msg)
+        mistral_output = mistral_response[0]["generated_text"]
+        self.chat_history.append([msg, mistral_output])
+        return "", mistral_output
+
 if __name__ == "__main__":
     # Define the device
     device = "cuda" if torch.cuda.is_available() else "cpu"
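One thing to watch in these two overrides: respond returns the bare generated string, while the submit.click wiring in ChatbotInterface sends the second output to the gr.Chatbot component, which expects a list of [user, bot] pairs (GaiaMinimed.respond returns self.chat_history for exactly this reason). A sketch of a respond that matches that contract, assuming the same falcon_inference helper from the first hunk:

def respond(self, msg, chatbot):
    # Query the hosted endpoint and record the exchange; return the
    # accumulated history so gr.Chatbot can render every turn.
    falcon_response = falcon_inference(msg)
    falcon_output = falcon_response[0]["generated_text"]
    self.chat_history.append([msg, falcon_output])
    return "", self.chat_history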
@@ -120,15 +166,15 @@ if __name__ == "__main__":
         with gr.Row() as row:
             with gr.Column() as col1:
                 with gr.Tab("GaiaMinimed") as gaia:
-                    gaia_bot =
+                    gaia_bot = GaiaMinimed("GaiaMinimed")
             with gr.Column() as col2:
                 with gr.Tab("MistralMed") as mistral:
-                    mistral_bot =
+                    mistral_bot = MistralBot("MistralMed")
                 with gr.Tab("Falcon-7B") as falcon7b:
-                    falcon_bot =
+                    falcon_bot = FalconBot("Falcon-7B")
 
-        gaia_bot.msg.
-        mistral_bot.msg.
-        falcon_bot.msg.
+        gaia_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
+        mistral_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
+        falcon_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])
 
     demo.launch()
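The three .input() listeners mirror whatever is typed in one tab into the other two textboxes; note that s[::1] is an identity slice that simply copies the string, so returning the value unchanged behaves the same. A self-contained sketch of the mirroring pattern, with hypothetical box names:

import gradio as gr

with gr.Blocks() as demo:
    # Three textboxes kept in sync: typing in any one of them copies
    # the text into the other two via its .input() event listener.
    box_a = gr.Textbox(label="GaiaMinimed")
    box_b = gr.Textbox(label="MistralMed")
    box_c = gr.Textbox(label="Falcon-7B")

    def mirror(text):
        return text, text

    box_a.input(mirror, inputs=box_a, outputs=[box_b, box_c])
    box_b.input(mirror, inputs=box_b, outputs=[box_a, box_c])
    box_c.input(mirror, inputs=box_c, outputs=[box_a, box_b])

demo.launch()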