import streamlit as st
from streamlit_chat import message
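
# Streamlit chat demo for heegyu/kodialogpt-v1, a Korean dialogue GPT-2.
# Assumed dependencies (not pinned by the original script): streamlit,
# streamlit-chat, transformers, and a torch backend.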


@st.cache(allow_output_mutation=True)  # legacy cache API; newer Streamlit uses st.cache_resource
def get_pipe():
    # Lazy import: the transformers stack is loaded once and cached across reruns.
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    tokenizer = AutoTokenizer.from_pretrained("heegyu/kodialogpt-v1")
    model = AutoModelForCausalLM.from_pretrained("heegyu/kodialogpt-v1")
    return pipeline("text-generation", model=model, tokenizer=tokenizer)
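
# A hypothetical smoke test outside the app (illustrative only; the prompt
# string below is not from the original demo):
#   pipe = get_pipe()
#   pipe("0 : hello\n1 : ")[0]["generated_text"]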


def get_response(generator, history, max_context: int = 7, bot_id: str = "1"):
    generation_args = dict(
        num_beams=4,
        repetition_penalty=2.0,
        no_repeat_ngram_size=4,
        eos_token_id=375,    # generation stops at token id 375 (value hard-coded in the original demo)
        max_new_tokens=64,
        do_sample=True,      # with num_beams > 1 this selects beam-sample decoding
        top_k=50,
        early_stopping=True,
    )

    # Render the dialogue as "speaker : utterance" lines; the user is speaker 0
    # and the bot is speaker 1, alternating by position in the history.
    context = [f"{i % 2} : {text}\n" for i, text in enumerate(history)]

    # Keep only the most recent turns so the prompt stays short.
    if len(context) > max_context:
        context = context[-max_context:]
    context = "".join(context) + f"{bot_id} : "

    # Generate a continuation and keep only the bot's first line of output.
    response = generator(context, **generation_args)[0]["generated_text"]
    response = response[len(context):].split("\n")[0]
    return response
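
# Worked example: with history = ["hi", "hello there", "how are you?"] the
# prompt sent to the generator is
#   "0 : hi\n1 : hello there\n0 : how are you?\n1 : "
# and everything after that prefix, up to the first newline, is returned.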


st.title("kodialogpt-v1 demo")

with st.spinner("loading model..."):
    generator = get_pipe()

# Persist the conversation across Streamlit reruns.
if "message_history" not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history  # alias; both names refer to the same list
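
# st.session_state survives the rerun Streamlit triggers on every widget
# interaction, so this list accumulates the whole conversation for the session.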

# Replay the conversation so far; user turns sit at even indices.
for i, message_ in enumerate(history):
    message(message_, is_user=i % 2 == 0, key=f"msg_{i}")  # unique key avoids duplicate-widget errors

input_ = st.text_input("YOU", value="")

# Only generate for a fresh submission: after the rerun below, the text box
# still holds the last input, which by then sits at history[-2].
if input_:  # equivalent to the original None/empty check
    if len(history) <= 1 or history[-2] != input_:
        with st.spinner("Generating a response..."):
            st.session_state.message_history.append(input_)
            response = get_response(generator, history)
            st.session_state.message_history.append(response)
        st.experimental_rerun()  # legacy API; st.rerun in newer Streamlit
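
# To try the demo locally (assuming this file is saved as app.py):
#   pip install streamlit streamlit-chat transformers torch
#   streamlit run app.py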