GGLS committed
Commit 7261d63
Parent: 64e6ab8

Upload 2 files

Files changed (2)
  1. app.py +118 -0
  2. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,118 @@
+import streamlit as st
+import torch
+import time
+from threading import Thread
+from transformers import (
+    AutoTokenizer,
+    AutoModelForCausalLM,
+    TextIteratorStreamer
+)
+
+
+# App title
+st.set_page_config(page_title="😶‍🌫️ FuseChat Model")
+
+root_path = "FuseAI"
+
+@st.cache_resource
+def load_model(model_name):
+    tokenizer = AutoTokenizer.from_pretrained(
+        f"{root_path}/{model_name}",
+        trust_remote_code=True,
+    )
+
+    if tokenizer.pad_token_id is None:
+        if tokenizer.eos_token_id is not None:
+            tokenizer.pad_token_id = tokenizer.eos_token_id
+        else:
+            tokenizer.pad_token_id = 0
+
+    model = AutoModelForCausalLM.from_pretrained(
+        f"{root_path}/{model_name}",
+        device_map="auto",
+        torch_dtype=torch.bfloat16,
+        load_in_4bit=True,
+        trust_remote_code=True,
+    )
+
+    model.eval()
+    return model, tokenizer
+
+
+with st.sidebar:
+    st.title('😶‍🌫️ FuseChat')
+    st.write('This chatbot is created using FuseChat, a model developed by FuseAI')
+    st.subheader('Models and parameters')
+    selected_model = st.selectbox('Choose a FuseChat model', ['FuseChat-7B-VaRM', 'FuseChat-7B-Slerp', 'FuseChat-7B-TA'], key='selected_model')
+    temperature = st.slider('temperature', min_value=0.01, max_value=5.0, value=0.1, step=0.01)
+    top_p = st.slider('top_p', min_value=0.01, max_value=1.0, value=0.9, step=0.01)
+    top_k = st.slider('top_k', min_value=1, max_value=1000, value=50, step=1)
+    repetition_penalty = st.slider('repetition penalty', min_value=1.0, max_value=2.0, value=1.2, step=0.05)
+    max_length = st.slider('max new tokens', min_value=32, max_value=2000, value=240, step=8)
+
+    with st.spinner('Loading model...'):
+        model, tokenizer = load_model(selected_model)
+
+# Store LLM-generated responses
+if "messages" not in st.session_state:
+    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
+
+# Display or clear chat messages
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.write(message["content"])
+
+def clear_chat_history():
+    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
+st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
+
+
+def generate_fusechat_response():
+    string_dialogue = "You are a helpful and harmless assistant."
+    for dict_message in st.session_state.messages:
+        if dict_message["role"] == "user":
+            string_dialogue += "GPT4 Correct User: " + dict_message["content"] + "<|end_of_turn|>"
+        else:
+            string_dialogue += "GPT4 Correct Assistant: " + dict_message["content"] + "<|end_of_turn|>"
+
+    input_ids = tokenizer(f"{string_dialogue}GPT4 Correct Assistant: ", return_tensors="pt").input_ids.to(model.device)  # keep inputs on the model's device
+    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
+        max_new_tokens=max_length,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        num_beams=1,
+        repetition_penalty=repetition_penalty,
+    )
+    t = Thread(target=model.generate, kwargs=generate_kwargs)
+    t.start()
+
+    outputs = []
+    for text in streamer:
+        outputs.append(text)
+    return "".join(outputs)
+
+# User-provided prompt
+if prompt := st.chat_input("Hello there! How are you doing?"):
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    with st.chat_message("user"):
+        st.write(prompt)
+
+# Generate a new response if the last message is not from the assistant
+if st.session_state.messages[-1]["role"] != "assistant":
+    with st.chat_message("assistant"):
+        with st.spinner("Thinking..."):
+            response = generate_fusechat_response()
+            placeholder = st.empty()
+            full_response = ''
+            for item in response:
+                full_response += item
+                time.sleep(0.05)
+                placeholder.markdown(full_response + "▌")
+            placeholder.markdown(full_response)
+    message = {"role": "assistant", "content": full_response}
+    st.session_state.messages.append(message)
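
For reference, the conversation string assembled in generate_fusechat_response follows the OpenChat-style "GPT4 Correct" turn format. A minimal standalone sketch of the prompt the model ends up receiving (the messages list and print call here are illustrative, not part of the commit):

    # Illustrative sketch of the prompt format the app builds.
    messages = [
        {"role": "assistant", "content": "How may I assist you today?"},
        {"role": "user", "content": "Hello there!"},
    ]
    prompt = "You are a helpful and harmless assistant."
    for m in messages:
        speaker = "GPT4 Correct User" if m["role"] == "user" else "GPT4 Correct Assistant"
        prompt += f"{speaker}: {m['content']}<|end_of_turn|>"
    prompt += "GPT4 Correct Assistant: "  # open turn that cues the model's reply
    print(prompt)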
requirements.txt ADDED
@@ -0,0 +1,8 @@
+streamlit==1.29.0
+bitsandbytes==0.42.0
+accelerate==0.25.0
+transformers==4.34.0
+torch==2.1.2
+protobuf==4.25.1
+scipy==1.11.4
+sentencepiece==0.1.99
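
A note on the pinned stack: bitsandbytes and accelerate are what make load_in_4bit=True and device_map="auto" in load_model work. As a rough sketch, the shorthand flag corresponds to passing an explicit quantization config (model ID shown for the default selectbox choice; treat the exact quantization defaults as an assumption of this sketch):

    import torch
    from transformers import AutoModelForCausalLM, BitsAndBytesConfig

    # Sketch: explicit 4-bit config roughly equivalent to load_in_4bit=True.
    quant_config = BitsAndBytesConfig(load_in_4bit=True)
    model = AutoModelForCausalLM.from_pretrained(
        "FuseAI/FuseChat-7B-VaRM",
        device_map="auto",
        torch_dtype=torch.bfloat16,
        quantization_config=quant_config,
        trust_remote_code=True,
    )

To try the app locally, the usual Streamlit workflow applies: pip install -r requirements.txt, then streamlit run app.py.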