File size: 4,812 Bytes
238795e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import re

import gradio as gr
import requests


# Backend endpoint that answers batched queries.
# NOTE(review): hard-coded public IP — presumably a dev deployment; consider
# reading from an env var before shipping.
BACKEND_ADDR = "http://46.8.230.12:8585/post"

# Initialize an empty list to store user queries
# (module-level mutable state shared by the Gradio callbacks below;
# each entry is a dict with keys: prompt, temperature, model, use_rag)
queries_list = []

# Convert model name to the corresponding model id that backend understands
# (keys must stay in sync with the gr.Radio choices defined in __main__)
llm_name2id = {
    "Llama-3.1-70B-Versatile": "llama-3.1-70b-versatile",
    "Llama-3-70B-8192": "llama3-70b-8192",
    "Llama-3-8B-8192": "llama3-8b-8192"
}

# Some default values (used both as initial widget values and by 'Clear')
DEFAULT_TEMP = 0.2
DEFAULT_MODEL = "Llama-3-70B-8192"
DEFAULT_USE_RAG = True


# Function to handle the 'Clear' button
def clear_queries():
    """Drop every queued query and reset all widgets to their defaults.

    Returns one value per Gradio output, in order:
    (query_txt, model, temperature, use_rag, queries_box, answers_box).
    """
    global queries_list

    # Rebind (rather than mutate) so the module-level list starts fresh.
    queries_list = []

    reset_values = ("", DEFAULT_MODEL, DEFAULT_TEMP, DEFAULT_USE_RAG, "", "")
    return reset_values


# Function to handle the 'Add' button
def add_to_list(query_txt, model, temperature, use_rag):
    """Queue a query (with its generation settings) for a later 'Submit'.

    Blank queries are ignored. Settings are stringified because the backend
    expects string fields. Raises KeyError if *model* is not a key of
    ``llm_name2id`` (the UI radio restricts choices, so this is unreachable
    from the interface).

    Returns a pair matching the handler's outputs [query_txt, queries_box]:
    ("" to clear the textbox, display string of all queued prompts).
    """
    global queries_list

    # Truthiness check instead of len() > 0 (also skips None defensively).
    if query_txt:
        queries_list.append(
            {
                "prompt": query_txt,
                "temperature": str(temperature),
                # Map the UI label to the model id the backend understands.
                "model": llm_name2id[model],
                "use_rag": str(use_rag),
            }
        )

    return "", generate_queries_str(queries_list)


# Function to handle the 'Submit' button
def submit(query_txt, model, temperature, use_rag):
    """Flush all queued queries to the backend and show its answers.

    Any text still sitting in the query box is queued first. If at least one
    query is queued, the whole batch is POSTed to ``BACKEND_ADDR`` and the
    queue is emptied; otherwise no request is made and the answers stay blank.

    Returns one value per Gradio output, in order:
    ("" to clear the textbox, queries display string, answers display string).

    Raises requests.RequestException on network failure/timeout — Gradio
    surfaces this as an error in the UI.
    """
    global queries_list

    if query_txt:
        _, queries = add_to_list(query_txt, model, temperature, use_rag)
    else:
        queries = generate_queries_str(queries_list)

    if queries_list:
        # timeout bounds the call so an unreachable backend cannot hang the
        # UI worker forever (requests blocks indefinitely by default).
        response = requests.post(BACKEND_ADDR, json=queries_list, timeout=120)
        answers = generate_answers_str(response.json())

        # Re-initialize the user's query list
        queries_list = []
    else:
        answers = ""

    return "", queries, answers


# Helper function to generate string representation of user queries
def generate_queries_str(queries_list):
    """Join the prompt of every queued query, separated by a 120-dash rule."""
    separator = "\n" + "-" * 120 + "\n"
    prompts = (entry["prompt"] for entry in queries_list)
    return separator.join(prompts)


# Helper function to generate string representation of model answers
def generate_answers_str(answers_list):
    """Join the cleaned 'answer' field of each backend response dict,
    separated by a 120-dash rule."""
    separator = "\n" + "-" * 120 + "\n"
    cleaned_answers = [clean(item["answer"]) for item in answers_list]
    return separator.join(cleaned_answers)


# Helper function to clean a model-generated answer
def clean(answer_str):
    """Strip LLM boilerplate and a stray leading colon from an answer.

    Removes a leading ``:`` (with any whitespace before it), deletes known
    filler phrases the model wraps around the real content, and trims
    surrounding whitespace.
    """
    # Raw string: '\s' inside a plain literal is an invalid escape sequence
    # (SyntaxWarning since Python 3.12, future SyntaxError).
    answer_str = re.sub(r'^\s*:', '', answer_str)

    # Filler phrases the model tends to emit around the actual content.
    garbages = [
        "Here is the generated paragraph:",
        "Let me know if this meets your requirements!",
    ]
    for g in garbages:
        answer_str = answer_str.replace(g, "")
    answer_str = answer_str.strip()
    return answer_str


if __name__ == "__main__":
    # Gradio interface
    # Layout follows component-creation order; handler `outputs` lists below
    # must match the tuple order returned by each callback.
    with gr.Blocks(theme=gr.themes.Default()) as demo:
        # Top row: query input (wide) next to the generation controls.
        with gr.Row():
            with gr.Column(scale=2):
                # RTL textbox; Persian placeholder: "Enter your question here..."
                query_txt = gr.Textbox(
                    placeholder="پرسش خود را این‌جا وارد کنید...",
                    label="Query", rtl=True)
            with gr.Column(scale=1):
                # Choices must stay in sync with the keys of llm_name2id.
                model = gr.Radio(
                    choices=[
                        "Llama-3-8B-8192",
                        "Llama-3-70B-8192",
                        "Llama-3.1-70B-Versatile",
                    ],
                    value=DEFAULT_MODEL,
                    label="LLM"
                )
                use_rag = gr.Checkbox(value=DEFAULT_USE_RAG, label="Use RAG?")
                temperature = gr.Slider(minimum=0, maximum=1,
                                        value=DEFAULT_TEMP,
                                        step=0.1, label="Temperature")

        # Action buttons.
        with gr.Row():
            clear_btn = gr.Button("Clear", variant="stop")
            add_btn = gr.Button("Add", variant="secondary")
            submit_btn = gr.Button("Submit", variant="primary")

        # Read-only output panes: queued queries (left), model answers (right).
        with gr.Row():
            with gr.Column():
                # Persian placeholder: "Your queries will be shown here..."
                queries_box = gr.Textbox(
                    placeholder="پرسش‌های شما این‌جا نمایش داده خواهد شد...",
                    label="Queries", interactive=False, rtl=True)
            with gr.Column():
                # Persian placeholder: "The model's answers will be shown here..."
                answers_box = gr.Textbox(
                    placeholder="پاسخ‌های مدل این‌جا نمایش داده خواهد شد...",
                    label="Answers", interactive=False, rtl=True)

        # Wire buttons to callbacks; outputs order mirrors each return tuple.
        clear_btn.click(
            fn=clear_queries,
            inputs=[],
            outputs=[query_txt, model, temperature, use_rag,
                     queries_box, answers_box]
        )
        add_btn.click(
            fn=add_to_list,
            inputs=[query_txt, model, temperature, use_rag],
            outputs=[query_txt, queries_box]
        )
        submit_btn.click(
            fn=submit,
            inputs=[query_txt, model, temperature, use_rag],
            outputs=[query_txt, queries_box, answers_box]
        )

    demo.launch()