cahya commited on
Commit
fcc6e50
1 Parent(s): 6e89a9e

fix the multi model

Browse files
Files changed (1) hide show
  1. app/app.py +25 -11
app/app.py CHANGED
@@ -12,14 +12,16 @@ import psutil
12
 
13
  MODELS = {
14
  "GPT-2 Small finetuned on Indonesian stories": {
15
- "name": "cahya/gpt2-small-indonesian-story"
 
16
  },
17
  "GPT-2 Medium finetuned on Indonesian stories": {
18
- "name": "cahya/gpt2-medium-indonesian-story"
 
19
  },
20
  }
21
 
22
- model = st.selectbox('Model',([
23
  'GPT-2 Small finetuned on Indonesian stories',
24
  'GPT-2 Medium finetuned on Indonesian stories']))
25
 
@@ -34,7 +36,7 @@ def get_generator(model_name: str):
34
  return text_generator
35
 
36
  @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
37
- def process(text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
38
  temperature: float = 1.0, max_time: float = 10.0, seed=42):
39
  # st.write("Cache miss: process")
40
  set_seed(seed)
@@ -56,6 +58,10 @@ st.markdown(
56
  session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
57
 
58
  ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
 
 
 
 
59
  prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
60
 
61
  # Update prompt
@@ -72,8 +78,11 @@ else:
72
  if session_state.prompt == "Custom":
73
  session_state.prompt_box = "Enter your text here"
74
  else:
 
 
 
75
  if session_state.prompt is not None and session_state.prompt_box is None:
76
- session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
77
 
78
  session_state.text = st.text_area("Enter text", session_state.prompt_box)
79
 
@@ -96,7 +105,7 @@ do_sample = st.sidebar.checkbox(
96
  value=True
97
  )
98
 
99
- top_k = 25
100
  top_p = 0.95
101
 
102
  if do_sample:
@@ -115,14 +124,16 @@ seed = st.sidebar.number_input(
115
  help="The number used to initialize a pseudorandom number generator"
116
  )
117
 
118
-
119
- text_generator = get_generator()
 
120
  if st.button("Run"):
121
  with st.spinner(text="Getting results..."):
122
  memory = psutil.virtual_memory()
123
  st.subheader("Result")
124
  time_start = time.time()
125
- result = process(text=session_state.text, max_length=int(max_length),
 
126
  temperature=temperature, do_sample=do_sample,
127
  top_k=int(top_k), top_p=float(top_p), seed=seed)
128
  time_end = time.time()
@@ -133,8 +144,11 @@ if st.button("Run"):
133
  translation = translate(result, "en", "id")
134
  st.write(translation.replace("\n", " \n"))
135
  # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
136
- st.write(f"*Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%*, available: {memory.available/(1024*1024*1024):.2f}GB")
137
- st.write(f"*Text generated in {time_diff:.5} seconds*")
 
 
 
138
 
139
  # Reset state
140
  session_state.prompt = None
 
12
 
13
  MODELS = {
14
  "GPT-2 Small finetuned on Indonesian stories": {
15
+ "name": "cahya/gpt2-small-indonesian-story",
16
+ "text_generator": None
17
  },
18
  "GPT-2 Medium finetuned on Indonesian stories": {
19
+ "name": "cahya/gpt2-medium-indonesian-story",
20
+ "text_generator": None
21
  },
22
  }
23
 
24
+ model = st.sidebar.selectbox('Model',([
25
  'GPT-2 Small finetuned on Indonesian stories',
26
  'GPT-2 Medium finetuned on Indonesian stories']))
27
 
 
36
  return text_generator
37
 
38
  @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
39
+ def process(text_generator, text: str, max_length: int = 100, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
40
  temperature: float = 1.0, max_time: float = 10.0, seed=42):
41
  # st.write("Cache miss: process")
42
  set_seed(seed)
 
58
  session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
59
 
60
  ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
61
+
62
+ print("# Prompt list", PROMPT_LIST)
63
+ print("# All Prompt", ALL_PROMPTS)
64
+
65
  prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
66
 
67
  # Update prompt
 
78
  if session_state.prompt == "Custom":
79
  session_state.prompt_box = "Enter your text here"
80
  else:
81
+ print(f"# prompt: {session_state.prompt}")
82
+ print(f"# prompt_box: {session_state.prompt_box}")
83
+ print(f"# PROMPT_LIST: {PROMPT_LIST.keys()}")
84
  if session_state.prompt is not None and session_state.prompt_box is None:
85
+ session_state.prompt_box = random.choice(PROMPT_LIST[prompt_group_name][session_state.prompt])
86
 
87
  session_state.text = st.text_area("Enter text", session_state.prompt_box)
88
 
 
105
  value=True
106
  )
107
 
108
+ top_k = 40
109
  top_p = 0.95
110
 
111
  if do_sample:
 
124
  help="The number used to initialize a pseudorandom number generator"
125
  )
126
 
127
+ for group_name in MODELS:
128
+ MODELS[group_name]["text_generator"] = get_generator(MODELS[group_name]["name"])
129
+ # text_generator = get_generator()
130
  if st.button("Run"):
131
  with st.spinner(text="Getting results..."):
132
  memory = psutil.virtual_memory()
133
  st.subheader("Result")
134
  time_start = time.time()
135
+ # text_generator = MODELS[model]["text_generator"]
136
+ result = process(MODELS[model]["text_generator"], text=session_state.text, max_length=int(max_length),
137
  temperature=temperature, do_sample=do_sample,
138
  top_k=int(top_k), top_p=float(top_p), seed=seed)
139
  time_end = time.time()
 
144
  translation = translate(result, "en", "id")
145
  st.write(translation.replace("\n", " \n"))
146
  # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
147
+ info = f"""
148
+ *Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
149
+ *Text generated in {time_diff:.5} seconds*
150
+ """
151
+ st.write(info)
152
 
153
  # Reset state
154
  session_state.prompt = None