Spaces:
Build error
Build error
Quyet
commited on
Commit
•
93152cc
1
Parent(s):
f30862a
fix global vs local chatbot
Browse files
README.md
CHANGED
@@ -18,6 +18,7 @@ For more information about this product, please visit this notion [page](https:/
|
|
18 |
|
19 |
### 2022/12/20
|
20 |
|
|
|
21 |
- Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
|
22 |
- Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
|
23 |
- TODO is written in the main file already
|
|
|
18 |
|
19 |
### 2022/12/20
|
20 |
|
21 |
+
- DONE turning the chatbot to session varible so that different sessions will show different conversation
|
22 |
- Chat flow will trigger euc 200 when detect a negative emotion with prob > threshold. Thus, only euc 100 and free chat consist of chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT be regularly (currently one trigger once during the conversation), because trigger to much will bother users
|
23 |
- Already fix the problem with dialog model. Now it's configured as the same as what it should be. Of course, that does not guarantee of good response
|
24 |
- TODO is written in the main file already
|
app.py
CHANGED
@@ -8,9 +8,10 @@ reference:
|
|
8 |
|
9 |
gradio vs streamlit
|
10 |
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
11 |
-
https://gradio.app/interface_state/
|
12 |
|
13 |
TODO
|
|
|
14 |
Add diagram in Gradio Interface showing sentimate analysis
|
15 |
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
16 |
Personalize: create database, load and save data
|
@@ -40,8 +41,21 @@ def option():
|
|
40 |
args = parser.parse_args()
|
41 |
return args
|
42 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
-
class ChatHelper: # store the list of messages that are showed in therapies
|
45 |
invalid_input = 'Invalid input, my friend :) Plz input again'
|
46 |
good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
|
47 |
good_case = 'Nice to hear that!'
|
@@ -130,28 +144,19 @@ class TherapyChatBot:
|
|
130 |
self.euc_100_emotion_degree = []
|
131 |
self.already_trigger_euc_200 = False
|
132 |
|
133 |
-
# chat and emotion-detection models
|
134 |
-
self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
|
135 |
-
self.ed_threshold = 0.3
|
136 |
-
self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
|
137 |
-
self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
|
138 |
-
self.eos = self.dialog_tokenizer.eos_token
|
139 |
-
# tokenizer.__call__ -> input_ids, attention_mask
|
140 |
-
# tokenizer.encode -> only inputs_ids, which is required by model.generate function
|
141 |
-
|
142 |
# chat history.
|
143 |
# TODO: if we want to personalize and save the conversation,
|
144 |
# we can load data from database
|
145 |
-
self.greeting = ChatHelper.greeting_template[self.chat_state]
|
146 |
-
self.history = {'input_ids': torch.tensor([[
|
147 |
-
'text':
|
148 |
if 'euc_100' in self.chat_state:
|
149 |
self.chat_state = 'euc_100.q.0'
|
150 |
|
151 |
def __call__(self, message, prefix=''):
|
152 |
# if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
|
153 |
if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
|
154 |
-
prediction =
|
155 |
prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
|
156 |
if self.run_on_own_server:
|
157 |
print(prediction)
|
@@ -160,7 +165,7 @@ class TherapyChatBot:
|
|
160 |
|
161 |
# if message is negative, change state immediately
|
162 |
if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
|
163 |
-
(emotion['label'] in ChatHelper.negative_emotions and emotion['score'] >
|
164 |
self.chat_state_prev = self.chat_state
|
165 |
self.chat_state = 'euc_200'
|
166 |
self.message_prev = message
|
@@ -171,7 +176,7 @@ class TherapyChatBot:
|
|
171 |
elif self.chat_state.startswith('euc_100'):
|
172 |
response = self.euc_100(message)
|
173 |
if self.chat_state == 'free_chat':
|
174 |
-
last_two_turns_ids =
|
175 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
|
176 |
|
177 |
elif self.chat_state.startswith('euc_200'):
|
@@ -185,7 +190,6 @@ class TherapyChatBot:
|
|
185 |
self.history['text'].append((self.message_prev, response))
|
186 |
else:
|
187 |
self.history['text'].append((message, response))
|
188 |
-
return self.history['text']
|
189 |
|
190 |
def euc_100(self, x):
|
191 |
_, subsection, entry = self.chat_state.split('.')
|
@@ -251,23 +255,23 @@ class TherapyChatBot:
|
|
251 |
message = self.message_prev
|
252 |
self.message_prev = x
|
253 |
self.chat_state = self.chat_state_prev
|
254 |
-
return self.__call__(message, response)
|
255 |
|
256 |
def free_chat(self, message):
|
257 |
-
message_ids =
|
258 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
|
259 |
input_ids = self.history['input_ids'].clone()
|
260 |
|
261 |
while True:
|
262 |
-
bot_output_ids =
|
263 |
do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
|
264 |
-
pad_token_id=
|
265 |
-
response =
|
266 |
skip_special_tokens=True)
|
267 |
if response.strip() != '':
|
268 |
break
|
269 |
-
elif input_ids[0].tolist().count(
|
270 |
-
idx = input_ids[0].tolist().index(
|
271 |
input_ids = input_ids[:, (idx+1):]
|
272 |
else:
|
273 |
input_ids = message_ids
|
@@ -282,20 +286,22 @@ class TherapyChatBot:
|
|
282 |
return response
|
283 |
|
284 |
|
285 |
-
if __name__ == '__main__':
|
286 |
-
|
287 |
-
|
|
|
|
|
288 |
|
289 |
title = 'PsyPlus Empathetic Chatbot'
|
290 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
291 |
-
|
|
|
292 |
iface = gr.Interface(
|
293 |
-
chat, 'text', chatbot,
|
294 |
allow_flagging='never', title=title, description=description,
|
295 |
)
|
296 |
|
297 |
-
# iface.queue(concurrency_count=5)
|
298 |
if args.run_on_own_server == 0:
|
299 |
iface.launch(debug=True)
|
300 |
else:
|
301 |
-
iface.launch(debug=True, share=True)
|
|
|
8 |
|
9 |
gradio vs streamlit
|
10 |
https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
|
11 |
+
https://gradio.app/interface_state/ -> global and local varible affect the separation of sessions
|
12 |
|
13 |
TODO
|
14 |
+
Add command to reset/jump to a function, e.g >reset, >euc_100
|
15 |
Add diagram in Gradio Interface showing sentimate analysis
|
16 |
Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
|
17 |
Personalize: create database, load and save data
|
|
|
41 |
args = parser.parse_args()
|
42 |
return args
|
43 |
|
44 |
+
args = option()
|
45 |
+
|
46 |
+
|
47 |
+
# store the list of messages that are showed in therapies and models as global variables
|
48 |
+
# let all chat-session-wise variables placed in TherapyChatBot
|
49 |
+
class ChatHelper:
|
50 |
+
# chat and emotion-detection models
|
51 |
+
ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
|
52 |
+
ed_threshold = 0.3
|
53 |
+
dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
|
54 |
+
dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
|
55 |
+
eos = dialog_tokenizer.eos_token
|
56 |
+
# tokenizer.__call__ -> input_ids, attention_mask
|
57 |
+
# tokenizer.encode -> only inputs_ids, which is required by model.generate function
|
58 |
|
|
|
59 |
invalid_input = 'Invalid input, my friend :) Plz input again'
|
60 |
good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
|
61 |
good_case = 'Nice to hear that!'
|
|
|
144 |
self.euc_100_emotion_degree = []
|
145 |
self.already_trigger_euc_200 = False
|
146 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
147 |
# chat history.
|
148 |
# TODO: if we want to personalize and save the conversation,
|
149 |
# we can load data from database
|
150 |
+
self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
|
151 |
+
self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
|
152 |
+
'text': self.greeting} if not self.account else open(f'database/{hash(self.account)}', 'rb')
|
153 |
if 'euc_100' in self.chat_state:
|
154 |
self.chat_state = 'euc_100.q.0'
|
155 |
|
156 |
def __call__(self, message, prefix=''):
|
157 |
# if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
|
158 |
if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
|
159 |
+
prediction = ChatHelper.ed_pipe(message)[0]
|
160 |
prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
|
161 |
if self.run_on_own_server:
|
162 |
print(prediction)
|
|
|
165 |
|
166 |
# if message is negative, change state immediately
|
167 |
if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
|
168 |
+
(emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > ChatHelper.ed_threshold):
|
169 |
self.chat_state_prev = self.chat_state
|
170 |
self.chat_state = 'euc_200'
|
171 |
self.message_prev = message
|
|
|
176 |
elif self.chat_state.startswith('euc_100'):
|
177 |
response = self.euc_100(message)
|
178 |
if self.chat_state == 'free_chat':
|
179 |
+
last_two_turns_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
|
180 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
|
181 |
|
182 |
elif self.chat_state.startswith('euc_200'):
|
|
|
190 |
self.history['text'].append((self.message_prev, response))
|
191 |
else:
|
192 |
self.history['text'].append((message, response))
|
|
|
193 |
|
194 |
def euc_100(self, x):
|
195 |
_, subsection, entry = self.chat_state.split('.')
|
|
|
255 |
message = self.message_prev
|
256 |
self.message_prev = x
|
257 |
self.chat_state = self.chat_state_prev
|
258 |
+
return self.__call__(message, prefix=response)
|
259 |
|
260 |
def free_chat(self, message):
|
261 |
+
message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
|
262 |
self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
|
263 |
input_ids = self.history['input_ids'].clone()
|
264 |
|
265 |
while True:
|
266 |
+
bot_output_ids = ChatHelper.dialog_model.generate(input_ids, max_length=1000,
|
267 |
do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
|
268 |
+
pad_token_id=ChatHelper.dialog_tokenizer.eos_token_id)
|
269 |
+
response = ChatHelper.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
|
270 |
skip_special_tokens=True)
|
271 |
if response.strip() != '':
|
272 |
break
|
273 |
+
elif input_ids[0].tolist().count(ChatHelper.dialog_tokenizer.eos_token_id) > 0:
|
274 |
+
idx = input_ids[0].tolist().index(ChatHelper.dialog_tokenizer.eos_token_id)
|
275 |
input_ids = input_ids[:, (idx+1):]
|
276 |
else:
|
277 |
input_ids = message_ids
|
|
|
286 |
return response
|
287 |
|
288 |
|
289 |
+
if __name__ == '__main__':
|
290 |
+
def chat(message, bot):
|
291 |
+
bot = bot or TherapyChatBot(args)
|
292 |
+
bot(message)
|
293 |
+
return bot.history['text'], bot
|
294 |
|
295 |
title = 'PsyPlus Empathetic Chatbot'
|
296 |
description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
|
297 |
+
greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
|
298 |
+
chatbot = gr.Chatbot(value=greeting)
|
299 |
iface = gr.Interface(
|
300 |
+
chat, ['text', 'state'], [chatbot, 'state'],
|
301 |
allow_flagging='never', title=title, description=description,
|
302 |
)
|
303 |
|
|
|
304 |
if args.run_on_own_server == 0:
|
305 |
iface.launch(debug=True)
|
306 |
else:
|
307 |
+
iface.launch(debug=True, share=True)
|