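"""Streamlit chat demo for the Korean dialogue model heegyu/kodialogpt-v1.

The running conversation is kept in st.session_state; on each turn the last
few messages are fed back to the model to generate the bot's reply.
"""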
import streamlit as st
from streamlit_chat import message

@st.cache_resource  # was st.cache(allow_output_mutation=True) on older Streamlit
def get_pipe():
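    # transformers is imported lazily so the heavy import only runs once,
    # when the cached pipeline is first built.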
    from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
    tokenizer = AutoTokenizer.from_pretrained("heegyu/kodialogpt-v1")
    model = AutoModelForCausalLM.from_pretrained("heegyu/kodialogpt-v1")
    return pipeline("text-generation", model=model, tokenizer=tokenizer)

def get_response(generator, history, max_context: int = 7, bot_id: str = '1'):
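    # Decoding: beam search with sampling, repetition penalties, and a newline
    # EOS token so the model emits a single turn. The prompt uses the per-line
    # "speaker : utterance" format this model appears to expect.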
    generation_args = dict(
        num_beams=4,
        repetition_penalty=2.0,
        no_repeat_ngram_size=4,
        eos_token_id=375,  # token id of "\n", so generation stops at end of line
        max_new_tokens=64,
        do_sample=True,
        top_k=50,
        early_stopping=True
    )
    # Encode the history as alternating speaker-tagged lines ("0 : ..." for the
    # user, "1 : ..." for the bot) and keep only the last `max_context` turns.
    context = []
    for i, text in enumerate(history):
        context.append(f"{i % 2} : {text}\n")

    if len(context) > max_context:
        context = context[-max_context:]
    context = "".join(context) + f"{bot_id} : "

    response = generator(
        context,
        **generation_args
    )[0]["generated_text"]
    # Strip the prompt, then keep only the first generated line as the reply.
    response = response[len(context):].split("\n")[0]
    return response

st.title("kodialogpt-v1 demo")

with st.spinner("loading model..."):
    generator = get_pipe()

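# The conversation alternates user / bot: even indices are the user's
# messages, odd indices are the bot's replies.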
if 'message_history' not in st.session_state:
    st.session_state.message_history = []
history = st.session_state.message_history

for i, message_ in enumerate(st.session_state.message_history):
    message(message_, is_user=i % 2 == 0)  # render every previous turn

input_ = st.text_input("YOU", value="")

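# text_input keeps its value across reruns, so only generate when this input
# hasn't already been answered (an answered input sits at history[-2]).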
if input_:
    if len(history) <= 1 or history[-2] != input_:
        with st.spinner("generating a response..."):
            st.session_state.message_history.append(input_)
            response = get_response(generator, history)
            st.session_state.message_history.append(response)
            st.rerun()  # was st.experimental_rerun() on older Streamlit