jfelipenc commited on
Commit
b16001c
โ€ข
1 Parent(s): 3c56fd6

New version using endpoints of Falcon and Mistral

Browse files
Files changed (1) hide show
  1. app.py +53 -7
app.py CHANGED
@@ -2,10 +2,30 @@ import random
2
  import time
3
  import torch
4
  import gradio as gr
 
5
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
6
  from peft import PeftModel, PeftConfig
7
  from textwrap import wrap, fill
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  # Functions to Wrap the Prompt Correctly
10
  def wrap_text(text, width=90):
11
  lines = text.split('\n')
@@ -65,8 +85,14 @@ class ChatbotInterface():
65
 
66
  self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])
67
 
 
 
 
 
 
 
 
68
  def respond(self, msg, history):
69
- #bot_message = random.choice(["Hello, I'm MedChat! How can I help you?", "Hello there! I'm Medchat, a medical assistant! How can I help you?"])
70
  formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
71
  input_ids = tokenizer.encode(
72
  formatted_input,
@@ -90,6 +116,26 @@ class ChatbotInterface():
90
 
91
  return "", self.chat_history
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  if __name__ == "__main__":
94
  # Define the device
95
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -120,15 +166,15 @@ if __name__ == "__main__":
120
  with gr.Row() as row:
121
  with gr.Column() as col1:
122
  with gr.Tab("GaiaMinimed") as gaia:
123
- gaia_bot = ChatbotInterface("GaiaMinimed")
124
  with gr.Column() as col2:
125
  with gr.Tab("MistralMed") as mistral:
126
- mistral_bot = ChatbotInterface("MistralMed")
127
  with gr.Tab("Falcon-7B") as falcon7b:
128
- falcon_bot = ChatbotInterface("Falcon-7B")
129
 
130
- gaia_bot.msg.change(fn=lambda s: (s[::1], s[::1]), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
131
- mistral_bot.msg.change(fn=lambda s: (s[::1], s[::1]), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
132
- falcon_bot.msg.change(fn=lambda s: (s[::1], s[::1]), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])
133
 
134
  demo.launch()
 
2
  import time
3
  import torch
4
  import gradio as gr
5
+ import requests
6
  from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
7
  from peft import PeftModel, PeftConfig
8
  from textwrap import wrap, fill
9
 
10
+ ## using Falcon 7b Instruct
11
+ Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
12
+ HEADERS = {"Authorization": "Bearer <HF TOKEN>"}
13
+ def falcon_query(payload):
14
+ response = requests.post(Falcon_API_URL, headers=HEADERS, json=payload)
15
+ return response.json()
16
+ def falcon_inference(input_text):
17
+ payload = {"inputs": input_text}
18
+ return falcon_query(payload)
19
+
20
+ ## using Mistral
21
+ Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"
22
+ def mistral_query(payload):
23
+ response = requests.post(Mistral_API_URL , headers=HEADERS, json=payload)
24
+ return response.json()
25
+ def mistral_inference(input_text):
26
+ payload = {"inputs": input_text}
27
+ return mistral_query(payload)
28
+
29
  # Functions to Wrap the Prompt Correctly
30
  def wrap_text(text, width=90):
31
  lines = text.split('\n')
 
85
 
86
  self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])
87
 
88
+ def respond(self, msg, chatbot):
89
+ raise NotImplementedError
90
+
91
+ class GaiaMinimed(ChatbotInterface):
92
+ def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
93
+ super().__init__(name, system_prompt)
94
+
95
  def respond(self, msg, history):
 
96
  formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
97
  input_ids = tokenizer.encode(
98
  formatted_input,
 
116
 
117
  return "", self.chat_history
118
 
119
+ class FalconBot(ChatbotInterface):
120
+ def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
121
+ super().__init__(name, system_prompt)
122
+
123
+ def respond(self, msg, chatbot):
124
+ falcon_response = falcon_inference(msg)
125
+ falcon_output = falcon_response[0]["generated_text"]
126
+ self.chat_history.append([msg, falcon_output])
127
+ return "", falcon_output
128
+
129
+ class MistralBot(ChatbotInterface):
130
+ def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
131
+ super().__init__(name, system_prompt)
132
+
133
+ def respond(self, msg, chatbot):
134
+ mistral_response = mistral_inference(msg)
135
+ mistral_output = mistral_response[0]["generated_text"]
136
+ self.chat_history.append([msg, mistral_output])
137
+ return "", mistral_output
138
+
139
  if __name__ == "__main__":
140
  # Define the device
141
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
166
  with gr.Row() as row:
167
  with gr.Column() as col1:
168
  with gr.Tab("GaiaMinimed") as gaia:
169
+ gaia_bot = GaiaMinimed("GaiaMinimed")
170
  with gr.Column() as col2:
171
  with gr.Tab("MistralMed") as mistral:
172
+ mistral_bot = MistralBot("MistralMed")
173
  with gr.Tab("Falcon-7B") as falcon7b:
174
+ falcon_bot = FalconBot("Falcon-7B")
175
 
176
+ gaia_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
177
+ mistral_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
178
+ falcon_bot.msg.input(fn=lambda s: (s[::1], s[::1]), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])
179
 
180
  demo.launch()