Quyet committed on
Commit
93152cc
1 Parent(s): f30862a

fix global vs local chatbot

Browse files
Files changed (2) hide show
  1. README.md +1 -0
  2. app.py +38 -32
README.md CHANGED
@@ -18,6 +18,7 @@ For more information about this product, please visit this notion [page](https:/
18
 
19
  ### 2022/12/20
20
 
 
21
  - Chat flow will trigger euc 200 when it detects a negative emotion with prob > threshold. Thus, only euc 100 and free chat make up the chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT fire regularly (currently it triggers only once during the conversation), because triggering too much would bother users
22
  - Already fixed the problem with the dialog model. Now it's configured the same as it should be. Of course, that does not guarantee a good response
23
  - TODO is written in the main file already
 
18
 
19
  ### 2022/12/20
20
 
21
+ - DONE turning the chatbot into a session variable so that different sessions will show different conversations
22
  - Chat flow will trigger euc 200 when it detects a negative emotion with prob > threshold. Thus, only euc 100 and free chat make up the chat loop, while euc 200 will pop up sometimes. I set the trigger to NOT fire regularly (currently it triggers only once during the conversation), because triggering too much would bother users
23
  - Already fixed the problem with the dialog model. Now it's configured the same as it should be. Of course, that does not guarantee a good response
24
  - TODO is written in the main file already
app.py CHANGED
@@ -8,9 +8,10 @@ reference:
8
 
9
  gradio vs streamlit
10
  https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
11
- https://gradio.app/interface_state/
12
 
13
  TODO
 
14
  Add diagram in Gradio Interface showing sentiment analysis
15
  Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
16
  Personalize: create database, load and save data
@@ -40,8 +41,21 @@ def option():
40
  args = parser.parse_args()
41
  return args
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- class ChatHelper: # store the list of messages that are showed in therapies
45
  invalid_input = 'Invalid input, my friend :) Plz input again'
46
  good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
47
  good_case = 'Nice to hear that!'
@@ -130,28 +144,19 @@ class TherapyChatBot:
130
  self.euc_100_emotion_degree = []
131
  self.already_trigger_euc_200 = False
132
 
133
- # chat and emotion-detection models
134
- self.ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
135
- self.ed_threshold = 0.3
136
- self.dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
137
- self.dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
138
- self.eos = self.dialog_tokenizer.eos_token
139
- # tokenizer.__call__ -> input_ids, attention_mask
140
- # tokenizer.encode -> only inputs_ids, which is required by model.generate function
141
-
142
  # chat history.
143
  # TODO: if we want to personalize and save the conversation,
144
  # we can load data from database
145
- self.greeting = ChatHelper.greeting_template[self.chat_state]
146
- self.history = {'input_ids': torch.tensor([[self.dialog_tokenizer.bos_token_id]]),
147
- 'text': [('', self.greeting)]} if not self.account else open(f'database/{hash(self.account)}', 'rb')
148
  if 'euc_100' in self.chat_state:
149
  self.chat_state = 'euc_100.q.0'
150
 
151
  def __call__(self, message, prefix=''):
152
  # if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
153
  if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
154
- prediction = self.ed_pipe(message)[0]
155
  prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
156
  if self.run_on_own_server:
157
  print(prediction)
@@ -160,7 +165,7 @@ class TherapyChatBot:
160
 
161
  # if message is negative, change state immediately
162
  if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
163
- (emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > self.ed_threshold):
164
  self.chat_state_prev = self.chat_state
165
  self.chat_state = 'euc_200'
166
  self.message_prev = message
@@ -171,7 +176,7 @@ class TherapyChatBot:
171
  elif self.chat_state.startswith('euc_100'):
172
  response = self.euc_100(message)
173
  if self.chat_state == 'free_chat':
174
- last_two_turns_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
175
  self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
176
 
177
  elif self.chat_state.startswith('euc_200'):
@@ -185,7 +190,6 @@ class TherapyChatBot:
185
  self.history['text'].append((self.message_prev, response))
186
  else:
187
  self.history['text'].append((message, response))
188
- return self.history['text']
189
 
190
  def euc_100(self, x):
191
  _, subsection, entry = self.chat_state.split('.')
@@ -251,23 +255,23 @@ class TherapyChatBot:
251
  message = self.message_prev
252
  self.message_prev = x
253
  self.chat_state = self.chat_state_prev
254
- return self.__call__(message, response)
255
 
256
  def free_chat(self, message):
257
- message_ids = self.dialog_tokenizer.encode(message + self.eos, return_tensors='pt')
258
  self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
259
  input_ids = self.history['input_ids'].clone()
260
 
261
  while True:
262
- bot_output_ids = self.dialog_model.generate(input_ids, max_length=1000,
263
  do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
264
- pad_token_id=self.dialog_tokenizer.eos_token_id)
265
- response = self.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
266
  skip_special_tokens=True)
267
  if response.strip() != '':
268
  break
269
- elif input_ids[0].tolist().count(self.dialog_tokenizer.eos_token_id) > 0:
270
- idx = input_ids[0].tolist().index(self.dialog_tokenizer.eos_token_id)
271
  input_ids = input_ids[:, (idx+1):]
272
  else:
273
  input_ids = message_ids
@@ -282,20 +286,22 @@ class TherapyChatBot:
282
  return response
283
 
284
 
285
- if __name__ == '__main__':
286
- args = option()
287
- chat = TherapyChatBot(args)
 
 
288
 
289
  title = 'PsyPlus Empathetic Chatbot'
290
  description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
291
- chatbot = gr.Chatbot(value=chat.history['text'])
 
292
  iface = gr.Interface(
293
- chat, 'text', chatbot,
294
  allow_flagging='never', title=title, description=description,
295
  )
296
 
297
- # iface.queue(concurrency_count=5)
298
  if args.run_on_own_server == 0:
299
  iface.launch(debug=True)
300
  else:
301
- iface.launch(debug=True, share=True) # server_name='0.0.0.0', server_port=2022
 
8
 
9
  gradio vs streamlit
10
  https://trojrobert.github.io/a-guide-for-deploying-and-serving-machine-learning-with-model-streamlit-vs-gradio/
11
+ https://gradio.app/interface_state/ -> global and local variables affect the separation of sessions
12
 
13
  TODO
14
+ Add a command to reset/jump to a function, e.g. >reset, >euc_100
15
  Add diagram in Gradio Interface showing sentiment analysis
16
  Gradio input timeout: cannot find a tutorial in Google -> don't know how to implement
17
  Personalize: create database, load and save data
 
41
  args = parser.parse_args()
42
  return args
43
 
44
+ args = option()
45
+
46
+
47
+ # store the list of messages that are shown in therapies, and the models, as global variables
48
+ # let all chat-session-wise variables be placed in TherapyChatBot
49
+ class ChatHelper:
50
+ # chat and emotion-detection models
51
+ ed_pipe = pipeline('text-classification', model=args.emotion_model, top_k=5, truncation=True)
52
+ ed_threshold = 0.3
53
+ dialog_model = GPT2LMHeadModel.from_pretrained(args.dialog_model)
54
+ dialog_tokenizer = GPT2Tokenizer.from_pretrained(args.dialog_model)
55
+ eos = dialog_tokenizer.eos_token
56
+ # tokenizer.__call__ -> input_ids, attention_mask
57
+ # tokenizer.encode -> only inputs_ids, which is required by model.generate function
58
 
 
59
  invalid_input = 'Invalid input, my friend :) Plz input again'
60
  good_mood_over = 'Whether your good mood is over? Any other details that you would like to recall?'
61
  good_case = 'Nice to hear that!'
 
144
  self.euc_100_emotion_degree = []
145
  self.already_trigger_euc_200 = False
146
 
 
 
 
 
 
 
 
 
 
147
  # chat history.
148
  # TODO: if we want to personalize and save the conversation,
149
  # we can load data from database
150
+ self.greeting = [('', ChatHelper.greeting_template[self.chat_state])]
151
+ self.history = {'input_ids': torch.tensor([[ChatHelper.dialog_tokenizer.bos_token_id]]),
152
+ 'text': self.greeting} if not self.account else open(f'database/{hash(self.account)}', 'rb')
153
  if 'euc_100' in self.chat_state:
154
  self.chat_state = 'euc_100.q.0'
155
 
156
  def __call__(self, message, prefix=''):
157
  # if prefix != None, which means this function is called from euc_200, thus already detected the negative emotion
158
  if (not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200:
159
+ prediction = ChatHelper.ed_pipe(message)[0]
160
  prediction = sorted(prediction, key=lambda x: x['score'], reverse=True)
161
  if self.run_on_own_server:
162
  print(prediction)
 
165
 
166
  # if message is negative, change state immediately
167
  if ((not prefix) and self.chat_state != 'euc_200' and not self.already_trigger_euc_200) and \
168
+ (emotion['label'] in ChatHelper.negative_emotions and emotion['score'] > ChatHelper.ed_threshold):
169
  self.chat_state_prev = self.chat_state
170
  self.chat_state = 'euc_200'
171
  self.message_prev = message
 
176
  elif self.chat_state.startswith('euc_100'):
177
  response = self.euc_100(message)
178
  if self.chat_state == 'free_chat':
179
+ last_two_turns_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
180
  self.history['input_ids'] = torch.cat([self.history['input_ids'], last_two_turns_ids], dim=-1)
181
 
182
  elif self.chat_state.startswith('euc_200'):
 
190
  self.history['text'].append((self.message_prev, response))
191
  else:
192
  self.history['text'].append((message, response))
 
193
 
194
  def euc_100(self, x):
195
  _, subsection, entry = self.chat_state.split('.')
 
255
  message = self.message_prev
256
  self.message_prev = x
257
  self.chat_state = self.chat_state_prev
258
+ return self.__call__(message, prefix=response)
259
 
260
  def free_chat(self, message):
261
+ message_ids = ChatHelper.dialog_tokenizer.encode(message + ChatHelper.eos, return_tensors='pt')
262
  self.history['input_ids'] = torch.cat([self.history['input_ids'], message_ids], dim=-1)
263
  input_ids = self.history['input_ids'].clone()
264
 
265
  while True:
266
+ bot_output_ids = ChatHelper.dialog_model.generate(input_ids, max_length=1000,
267
  do_sample=True, top_p=0.9, temperature=0.8, num_beams=2,
268
+ pad_token_id=ChatHelper.dialog_tokenizer.eos_token_id)
269
+ response = ChatHelper.dialog_tokenizer.decode(bot_output_ids[0][input_ids.shape[-1]:],
270
  skip_special_tokens=True)
271
  if response.strip() != '':
272
  break
273
+ elif input_ids[0].tolist().count(ChatHelper.dialog_tokenizer.eos_token_id) > 0:
274
+ idx = input_ids[0].tolist().index(ChatHelper.dialog_tokenizer.eos_token_id)
275
  input_ids = input_ids[:, (idx+1):]
276
  else:
277
  input_ids = message_ids
 
286
  return response
287
 
288
 
289
+ if __name__ == '__main__':
290
+ def chat(message, bot):
291
+ bot = bot or TherapyChatBot(args)
292
+ bot(message)
293
+ return bot.history['text'], bot
294
 
295
  title = 'PsyPlus Empathetic Chatbot'
296
  description = 'Gradio demo for product of PsyPlus. Based on rule-based CBT and conversational AI model DialoGPT'
297
+ greeting = [('', ChatHelper.greeting_template[args.initial_chat_state])]
298
+ chatbot = gr.Chatbot(value=greeting)
299
  iface = gr.Interface(
300
+ chat, ['text', 'state'], [chatbot, 'state'],
301
  allow_flagging='never', title=title, description=description,
302
  )
303
 
 
304
  if args.run_on_own_server == 0:
305
  iface.launch(debug=True)
306
  else:
307
+ iface.launch(debug=True, share=True)