# medical-chatbot / app.py
# jfelipenc — "Update app.py (#2)" (commit 14e4c0c)
import os
import torch
import gradio as gr
import requests
from transformers import AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM, MistralForCausalLM
from peft import PeftModel, PeftConfig
from textwrap import wrap, fill
## using Falcon 7b Instruct
Falcon_API_URL = "https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct"
hf_token = os.getenv("HUGGINGFACE_TOKEN")
# Bearer-token header for the Hugging Face Inference API. The f-string is
# required so the token's value (not the literal text "{hf_token}") is sent.
# With no token configured, send no Authorization header instead of "Bearer None".
HEADERS = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}


def falcon_query(payload, timeout=60):
    """POST ``payload`` to the Falcon-7B-Instruct Inference API and return the parsed JSON.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "..."}``.
        timeout: Seconds to wait for the API before ``requests`` raises,
            so a stalled endpoint cannot hang the UI handler forever.

    Returns:
        The decoded JSON response. Callers index it as
        ``[0]["generated_text"]`` on success; on failure it may be some
        other object (for example an error dict).
    """
    response = requests.post(Falcon_API_URL, headers=HEADERS, json=payload, timeout=timeout)
    return response.json()


def falcon_inference(input_text):
    """Send ``input_text`` to Falcon-7B-Instruct and return the raw API response."""
    payload = {"inputs": input_text}
    return falcon_query(payload)
## using Mistral
Mistral_API_URL = "https://api-inference.huggingface.co/models/mistralai/Mistral-7B-v0.1"


def mistral_query(payload, timeout=60):
    """POST ``payload`` to the Mistral-7B Inference API and return the parsed JSON.

    Uses the module-level ``HEADERS`` for authentication.

    Args:
        payload: JSON-serialisable request body, e.g. ``{"inputs": "..."}``.
        timeout: Seconds to wait for the API before ``requests`` raises,
            so a stalled endpoint cannot hang the UI handler forever.

    Returns:
        The decoded JSON response. Callers index it as
        ``[0]["generated_text"]`` on success.
    """
    response = requests.post(Mistral_API_URL, headers=HEADERS, json=payload, timeout=timeout)
    return response.json()


def mistral_inference(input_text):
    """Send ``input_text`` to Mistral-7B and return the raw API response."""
    payload = {"inputs": input_text}
    return mistral_query(payload)
# Functions to Wrap the Prompt Correctly
def wrap_text(text, width=90):
    """Hard-wrap every line of ``text`` to at most ``width`` columns.

    Existing line breaks are kept: each newline-separated line is passed
    through ``textwrap.fill`` on its own, so blank lines stay blank.
    """
    return "\n".join(fill(segment, width=width) for segment in text.split("\n"))
class ChatbotInterface():
    """Base Gradio chat panel: a Chatbot display, a message box, Submit and Clear buttons.

    Subclasses implement :meth:`respond`, which is wired to the Submit button
    with ``(message, chatbot_value)`` as inputs and ``(message, chatbot_value)``
    as outputs. The constructor creates Gradio components and registers event
    listeners, so it must run inside an active ``gr.Blocks`` context.

    Args:
        name: Display name of the bot (subclasses may use it in prompts).
        system_prompt: Instruction text subclasses may prepend to user input.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        self.name = name
        self.system_prompt = system_prompt
        self.chatbot = gr.Chatbot()
        # Transcript for subclasses that store turns on the instance.
        # NOTE(review): an instance attribute is shared by every browser session
        # using the app; the per-session Chatbot value passed to ``respond`` is
        # the safer source of history.
        self.chat_history = []
        with gr.Row() as row:
            # NOTE(review): assigning ``justify`` after construction is probably a
            # no-op for gr.Row — confirm against the installed Gradio version.
            row.justify = "end"
            self.msg = gr.Textbox(scale=7)
            self.submit = gr.Button("Submit", scale=1)
            clear = gr.ClearButton([self.msg, self.chatbot])
        # ClearButton only resets the UI components; also drop the stored
        # transcript so cleared turns are not re-displayed on the next reply.
        clear.click(self._reset_history)
        self.submit.click(self.respond, [self.msg, self.chatbot], [self.msg, self.chatbot])

    def _reset_history(self):
        """Empty the stored transcript (triggered by the Clear button)."""
        self.chat_history.clear()

    def respond(self, msg, chatbot):
        """Produce a reply to ``msg``; must be overridden by subclasses.

        Args:
            msg: Text from the message box.
            chatbot: Current value of this panel's Chatbot.

        Returns:
            A ``(new_message_box_value, new_chatbot_value)`` pair.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError
class GaiaMinimed(ChatbotInterface):
    """Chat panel backed by the locally loaded GaiaMiniMed PEFT model.

    Uses the module-level ``tokenizer`` and ``peft_model`` globals, which are
    created in the ``__main__`` section before the UI is built.
    """

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def respond(self, msg, history):
        """Generate a reply to ``msg`` and append the turn to this session's history.

        Args:
            msg: The user's message from the textbox.
            history: Current Chatbot value for the calling session (list of
                ``[user, bot]`` pairs); may be empty or None.

        Returns:
            ``("", updated_history)`` — clears the textbox and refreshes the Chatbot.
        """
        formatted_input = f"{{{{ {self.system_prompt} }}}}\nUser: {msg}\n{self.name}:"
        input_ids = tokenizer.encode(
            formatted_input,
            return_tensors="pt",
            add_special_tokens=False
        )
        # Keep the prompt on the same device as the model weights.
        input_ids = input_ids.to(peft_model.device)
        response = peft_model.generate(
            input_ids=input_ids,
            max_length=500,
            use_cache=False,
            early_stopping=False,
            bos_token_id=peft_model.config.bos_token_id,
            eos_token_id=peft_model.config.eos_token_id,
            pad_token_id=peft_model.config.eos_token_id,
            temperature=0.4,
            do_sample=True
        )
        # For a causal LM, generate() returns prompt + continuation; decode only
        # the newly generated tokens so the prompt is not echoed back to the user.
        new_tokens = response[0][input_ids.shape[-1]:]
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
        updated_history = list(history or [])
        # Display the user's own message, not the internal prompt with system text.
        updated_history.append([msg, response_text])
        return "", updated_history
class FalconBot(ChatbotInterface):
    """Chat panel backed by the hosted Falcon-7B-Instruct Inference API."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def _extract_reply(self, api_response):
        """Return the generated text from an Inference API response.

        A successful call is indexed as ``[0]["generated_text"]``. Anything else
        (presumably an ``{"error": ...}`` dict, e.g. while the model loads) is
        turned into a readable message for the chat instead of raising.
        """
        if isinstance(api_response, dict) and "error" in api_response:
            return f"[{self.name}] API error: {api_response['error']}"
        try:
            return api_response[0]["generated_text"]
        except (IndexError, KeyError, TypeError):
            return f"[{self.name}] Unexpected API response: {api_response!r}"

    def respond(self, msg, chatbot):
        """Query Falcon with ``msg`` and append the turn to this session's history.

        Args:
            msg: The user's message from the textbox.
            chatbot: Current Chatbot value for the calling session (list of
                ``[user, bot]`` pairs); may be empty or None.

        Returns:
            ``("", updated_history)`` — a Chatbot needs the full list of turns,
            not a bare reply string.
        """
        try:
            api_response = falcon_inference(msg)
        except (requests.RequestException, ValueError) as err:
            # Network failures / timeouts / non-JSON bodies become a chat message.
            reply = f"[{self.name}] Request failed: {err}"
        else:
            reply = self._extract_reply(api_response)
        history = list(chatbot or [])
        history.append([msg, reply])
        return "", history
class MistralBot(ChatbotInterface):
    """Chat panel backed by the hosted Mistral-7B Inference API."""

    def __init__(self, name, system_prompt="You are an expert medical analyst that helps users with any medical related information."):
        super().__init__(name, system_prompt)

    def _extract_reply(self, api_response):
        """Return the generated text from an Inference API response.

        A successful call is indexed as ``[0]["generated_text"]``. Anything else
        (presumably an ``{"error": ...}`` dict, e.g. while the model loads) is
        turned into a readable message for the chat instead of raising.
        """
        if isinstance(api_response, dict) and "error" in api_response:
            return f"[{self.name}] API error: {api_response['error']}"
        try:
            return api_response[0]["generated_text"]
        except (IndexError, KeyError, TypeError):
            return f"[{self.name}] Unexpected API response: {api_response!r}"

    def respond(self, msg, chatbot):
        """Query Mistral with ``msg`` and append the turn to this session's history.

        Args:
            msg: The user's message from the textbox.
            chatbot: Current Chatbot value for the calling session (list of
                ``[user, bot]`` pairs); may be empty or None.

        Returns:
            ``("", updated_history)`` — a Chatbot needs the full list of turns,
            not a bare reply string.
        """
        try:
            api_response = mistral_inference(msg)
        except (requests.RequestException, ValueError) as err:
            # Network failures / timeouts / non-JSON bodies become a chat message.
            reply = f"[{self.name}] Request failed: {err}"
        else:
            reply = self._extract_reply(api_response)
        history = list(chatbot or [])
        history.append([msg, reply])
        return "", history
if __name__ == "__main__":
    # Define the device
    # NOTE(review): ``device`` is not used below — the model is loaded without
    # an explicit placement, so it stays wherever from_pretrained puts it.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Use the base model's ID
    base_model_id = "tiiuae/falcon-7b-instruct"
    model_directory = "Tonic/GaiaMiniMed"
    # Instantiate the Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, padding_side="left")
    # Specify the configuration class for the model
    model_config = AutoConfig.from_pretrained(base_model_id)
    # Load the PEFT model with the specified configuration
    # NOTE(review): weights are loaded from the adapter repo id; if
    # Tonic/GaiaMiniMed only ships adapter files, loading ``base_model_id``
    # here may be what was intended — confirm.
    peft_model = AutoModelForCausalLM.from_pretrained(model_directory, config=model_config)
    peft_model = PeftModel.from_pretrained(peft_model, model_directory)

    with gr.Blocks() as demo:
        with gr.Row() as intro:
            gr.Markdown(
                """
                # MedChat: Your Medical Assistant Chatbot
                Welcome to MedChat, your friendly medical assistant chatbot! 🩺
                Dive into a world of medical expertise where you can interact with three specialized chatbots, all trained on the latest and most comprehensive medical dataset. Whether you have health-related questions, need medical advice, or just want to learn more about your well-being, MedChat is here to help!
                ## How it Works
                Simply type your medical query or concern, and let MedChat's advanced algorithms provide you with accurate and reliable responses.
                ## Explore and Compare
                Feel like experimenting? Click the **Submit to All** button and witness the magic as all three chatbots compete to provide you with the best possible answer! It's a unique opportunity to compare the insights from different models and choose the one that suits your needs the best.
                _Ready to get started? Type your question and let's begin!_
                """
            )
        with gr.Row() as row:
            with gr.Column() as col1:
                with gr.Tab("GaiaMinimed") as gaia:
                    gaia_bot = GaiaMinimed("GaiaMinimed")
            with gr.Column() as col2:
                with gr.Tab("MistralMed") as mistral:
                    mistral_bot = MistralBot("MistralMed")
                with gr.Tab("Falcon-7B") as falcon7b:
                    falcon_bot = FalconBot("Falcon-7B")

        # The intro text promises a "Submit to All" button: send each panel's
        # (mirrored) message to its own bot in one click.
        with gr.Row():
            submit_all = gr.Button("Submit to All")
        for bot in (gaia_bot, mistral_bot, falcon_bot):
            submit_all.click(bot.respond, [bot.msg, bot.chatbot], [bot.msg, bot.chatbot])

        # Mirror whatever is typed in one message box into the other two.
        gaia_bot.msg.input(fn=lambda s: (s, s), inputs=gaia_bot.msg, outputs=[mistral_bot.msg, falcon_bot.msg])
        mistral_bot.msg.input(fn=lambda s: (s, s), inputs=mistral_bot.msg, outputs=[gaia_bot.msg, falcon_bot.msg])
        falcon_bot.msg.input(fn=lambda s: (s, s), inputs=falcon_bot.msg, outputs=[gaia_bot.msg, mistral_bot.msg])

    demo.launch()