|
import gradio as gr |
|
from gradio.themes.utils import colors |
|
from t5 import T5 |
|
from koalpaca import KoAlpaca |
|
|
|
# When True, the launcher loads only T5 and skips loading KoAlpaca.
LOCAL_TEST: bool = False

# Loaded model instances, appended at startup in the same order as the
# model-selection radio's options (the radio reports an index into this list).
MODELS: list = []

# Index into MODELS of the model that answers chat messages; written by
# change_model_index and read by chat.
cur_index: int = 0
|
|
|
def prepare_theme():
    """Build the Gradio theme used by the chat app.

    Starts from ``gr.themes.Default`` with a gray primary hue and emerald
    secondary/neutral hues, then overrides backgrounds, borders, text and
    accent colours. Each colour override is given the same value for the
    light key and its ``_dark`` counterpart, so those settings do not change
    between light and dark mode.

    Returns:
        The configured theme, suitable for ``gr.Blocks(theme=...)``.
    """
    # Overrides that receive the same value in light mode and dark mode.
    mirrored = {
        "body_background_fill": "*primary_800",
        "block_background_fill": "*primary_700",
        "border_color_primary": "*secondary_300",
        "input_background_fill": "*primary_700",
        "background_fill_secondary": "*primary_700",
        "body_text_color": "white",
        "block_label_text_color": "*secondary_300",
        "block_label_background_fill": "*primary_800",
        "color_accent_soft": "*primary_600",
    }

    overrides = {}
    for key, value in mirrored.items():
        overrides[key] = value
        overrides[f"{key}_dark"] = value

    # Border widths are passed on their base keys only; no ``_dark`` key is
    # set for them.
    overrides["block_border_width"] = "3px"
    overrides["input_border_width"] = "2px"

    base = gr.themes.Default(
        primary_hue=colors.gray,
        secondary_hue=colors.emerald,
        neutral_hue=colors.emerald,
    )
    return base.set(**overrides)
|
|
|
def chat(message, chat_history):
    """Answer *message* with the active model and append the turn to the history.

    Args:
        message: Text submitted from the textbox.
        chat_history: Current Chatbot value, a list of ``(user, bot)`` tuples.
            ``None`` is treated as an empty history.

    Returns:
        ``("", chat_history)``. The empty string clears the textbox and the
        updated history refreshes the Chatbot.

    Raises:
        gr.Error: if ``cur_index`` does not point at a loaded model. For
            example, KoAlpaca was selected but not loaded because LOCAL_TEST
            is set. Without this check the handler crashed with an IndexError.
    """
    if chat_history is None:
        chat_history = []

    # Blank submissions are ignored rather than spending a model call on them.
    if not message or not message.strip():
        return "", chat_history

    if not 0 <= cur_index < len(MODELS):
        raise gr.Error(f"Selected model (index {cur_index}) is not loaded.")

    response = MODELS[cur_index].generate(message)
    chat_history.append((message, response))
    return "", chat_history
|
|
|
def change_model_index(idx):
    """Make the model at position *idx* of MODELS the one chat() uses.

    Wired to the model-selection radio, which reports the chosen option's
    index. The new index is printed as a debugging aid.

    Args:
        idx: Position of the selected model in MODELS.
    """
    global cur_index
    cur_index = idx  # picked up by chat() on the next message
    print(idx)
|
|
|
if __name__ == '__main__':
    theme = prepare_theme()

    # Load the models and record each display name in the same order. The
    # radio (type='index') then only offers models that are really in MODELS.
    # With LOCAL_TEST set, the KoAlpaca option is hidden instead of pointing
    # past the end of the list.
    model_names = ['T5']
    MODELS.append(T5())

    if not LOCAL_TEST:
        model_names.append('KoAlpaca')
        MODELS.append(KoAlpaca())

    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            rd = gr.Radio(model_names, value=model_names[0], type='index',
                          label='Model Selection', show_label=True,
                          interactive=True)

            with gr.Column(scale=5):
                chatbot = gr.Chatbot(label=model_names[0],
                                     bubble_full_width=False)

                with gr.Row():
                    txt = gr.Textbox(show_label=False,
                                     placeholder='Send a message...',
                                     container=False)

        # Submitting the textbox runs chat(); its "" output clears the box.
        txt.submit(chat, [txt, chatbot], [txt, chatbot])

        # Switching models updates the active index and relabels the chat
        # panel, which previously kept showing "T5" after a switch.
        rd.select(change_model_index, [rd])
        rd.select(lambda idx: gr.update(label=model_names[idx]), [rd], [chatbot])

    demo.launch(debug=True, share=True)