Fix the multi-model support
app/app.py +25 -11 CHANGED
@@ -12,14 +12,16 @@ import psutil
 
 MODELS = {
     "GPT-2 Small finetuned on Indonesian stories": {
-        "name": "cahya/gpt2-small-indonesian-story"
+        "name": "cahya/gpt2-small-indonesian-story",
+        "text_generator": None
     },
     "GPT-2 Medium finetuned on Indonesian stories": {
-        "name": "cahya/gpt2-medium-indonesian-story"
+        "name": "cahya/gpt2-medium-indonesian-story",
+        "text_generator": None
     },
 }
 
-model = st.selectbox('Model',([
+model = st.sidebar.selectbox('Model',([
     'GPT-2 Small finetuned on Indonesian stories',
     'GPT-2 Medium finetuned on Indonesian stories']))
 
@@ -34,7 +36,7 @@ def get_generator(model_name: str):
     return text_generator
 
 @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
-def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
+def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
             temperature: float = 1.0, max_time: float = 10.0, seed=42):
     # st.write("Cache miss: process")
     set_seed(seed)
@@ -56,6 +58,10 @@ st.markdown(
 session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
 
 ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
+
+print("# Prompt list", PROMPT_LIST)
+print("# All Prompt", ALL_PROMPTS)
+
 prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
 
 # Update prompt
@@ -72,8 +78,11 @@ else:
 if session_state.prompt == "Custom":
     session_state.prompt_box = "Enter your text here"
 else:
+    print(f"# prompt: {session_state.prompt}")
+    print(f"# prompt_box: {session_state.prompt_box}")
+    print(f"# PROMPT_LIST: {PROMPT_LIST.keys()}")
     if session_state.prompt is not None and session_state.prompt_box is None:
-        session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
+        session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
 
 session_state.text = st.text_area("Enter text", session_state.prompt_box)
 
@@ -96,7 +105,7 @@ do_sample = st.sidebar.checkbox(
     value=True
 )
 
-top_k =
+top_k = 40
 top_p = 0.95
 
 if do_sample:
@@ -115,14 +124,16 @@ seed = st.sidebar.number_input(
     help="The number used to initialize a pseudorandom number generator"
 )
 
-
-text_generator = get_generator()
+for group_name in MODELS:
+    MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
+# text_generator = get_generator()
 if st.button("Run"):
     with st.spinner(text="Getting results..."):
         memory = psutil.virtual_memory()
         st.subheader("Result")
         time_start = time.time()
-        result = process(text=session_state.text, max_length=int(max_length),
+        # text_generator = MODELS[model]["text_generator"]
+        result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
                          temperature=temperature, do_sample=do_sample,
                          top_k=int(top_k), top_p=float(top_p), seed=seed)
         time_end = time.time()
@@ -133,8 +144,11 @@ if st.button("Run"):
         translation = translate(result, "en", "id")
        st.write(translation.replace("\n", " \n"))
         # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
-
-
+        info = f"""
+        *Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
+        *Text generated in {time_diff:.5} seconds*
+        """
+        st.write(info)
 
         # Reset state
         session_state.prompt = None
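Note on the loading side of this change: get_generator(model_name) itself is not part of the diff; only its signature and return line appear as hunk context. A minimal sketch of how such a cached per-model loader could look with transformers and the legacy st.cache API follows; the pipeline call and cache options here are assumptions for illustration, not code from this repo.

import streamlit as st
import tokenizers
from transformers import pipeline

@st.cache(allow_output_mutation=True, hash_funcs={tokenizers.Tokenizer: id})
def get_generator(model_name: str):
    # Build one text-generation pipeline per model name; Streamlit caches it
    # across reruns, so switching models in the sidebar does not re-download
    # or re-instantiate weights that were already loaded.
    text_generator = pipeline("text-generation", model=model_name)
    return text_generator

The hash_funcs={tokenizers.Tokenizer: id} mapping (also used on process() in the diff) tells Streamlit's cache hasher to key the fast tokenizer inside the pipeline by object identity instead of trying to hash it, which it cannot do on its own.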
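The prompt_box fix in the fourth hunk, indexing PROMPT_LIST[prompt_group_name][session_state.prompt] instead of PROMPT_LIST[session_state.prompt], implies PROMPT_LIST is nested by prompt group and then by prompt name, each mapping to a list of candidate texts. A small sketch of that assumed shape, with placeholder prompts rather than the repo's actual ones:

import random

# Assumed nesting: group name -> prompt name -> list of candidate prompt texts.
PROMPT_LIST = {
    "Indonesian stories": {
        "Adventure": ["Pada suatu hari ..."],
        "Romance": ["Di sebuah kota kecil ..."],
    },
}

prompt_group_name = "Indonesian stories"
selected_prompt = "Adventure"
prompt_box = random.choice(PROMPT_LIST[prompt_group_name][selected_prompt])

Under this shape, the old one-level lookup PROMPT_LIST[session_state.prompt] would raise a KeyError for any non-custom prompt, which is what the added [prompt_group_name] index avoids.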
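For reference, the info block added in the last hunk combines system memory from psutil with wall-clock generation time. A self-contained sketch of the same measurement, with the generation call stubbed out and variable names matching the diff:

import time
import psutil

time_start = time.time()
# ... the model call (process(...) in the app) would run here ...
time_end = time.time()
time_diff = time_end - time_start

memory = psutil.virtual_memory()
info = (
    f"*Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, "
    f"used: {memory.percent}%, "
    f"available: {memory.available / (1024 * 1024 * 1024):.2f}GB*\n"
    f"*Text generated in {time_diff:.5} seconds*"
)
print(info)

Worth noting: in the diff, memory is sampled at the top of the spinner block, before generation runs, so the reported figures describe the state just before the run rather than its peak usage.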
|