|
import streamlit as st |
|
from transformers import AutoTokenizer |
|
import json |
|
import tempfile |
|
import os |
|
import uuid |
|
import copy |
|
|
|
# Use the wide page layout so the side-by-side columns below have room.
st.set_page_config(
    layout="wide",
)
|
|
|
def sanitize_jinja2(jinja_lines):
    """Collapse a multi-line Jinja template into a single line.

    Every line has its leading space characters (only ``" "``, not tabs)
    and its trailing ``"\\n"`` characters removed; the results are then
    concatenated with no separator.

    Args:
        jinja_lines: Iterable of template lines, e.g. the result of
            ``file.readlines()``.

    Returns:
        The concatenated one-line template ("" when there are no lines).
    """
    # A single str.join pass instead of repeated `+=` concatenation, which
    # can degrade to quadratic time on long templates.
    return "".join(line.lstrip(" ").rstrip("\n") for line in jinja_lines)
|
|
|
@st.cache_resource
def get_existing_templates():
    """List the predefined chat template files in ``./templates``.

    Returns:
        A list whose first element is ``None`` (the "no template" choice
        offered by the template selectbox) followed by the template file
        names, sorted alphabetically. If ``./templates`` does not exist,
        ``[None]`` is returned.

    Note:
        The result is cached by ``st.cache_resource`` (no TTL is set), so
        templates added while the app is running are not picked up until
        the cache is cleared or the app restarts.
    """
    templates_dir = "./templates"
    if not os.path.isdir(templates_dir):
        # No templates shipped: only offer the "use the tokenizer's own
        # template" choice instead of crashing the page.
        return [None]
    # os.listdir order is arbitrary; sort for a stable UI. Skip
    # sub-directories, which the template reader cannot open as files.
    template_names = sorted(
        name
        for name in os.listdir(templates_dir)
        if os.path.isfile(os.path.join(templates_dir, name))
    )
    return [None] + template_names
|
|
|
|
|
|
|
|
|
# Per-session state. Streamlit reruns this script top to bottom on every
# interaction, so each key is only initialised when it is missing.
if 'tokenizer_json' not in st.session_state:
    # Directory the loaded tokenizer was saved to (holds tokenizer_config.json).
    st.session_state['tokenizer_json'] = None

if 'tokenizer' not in st.session_state:
    st.session_state['tokenizer'] = None

if 'repo_normalized_name' not in st.session_state:
    st.session_state['repo_normalized_name'] = None

if 'repo_id' not in st.session_state:
    st.session_state['repo_id'] = None

if 'input_jinja_template' not in st.session_state:
    st.session_state['input_jinja_template'] = ""

if 'uuid' not in st.session_state:
    # Random id that gives this browser session its own scratch directory.
    st.session_state['uuid'] = uuid.uuid4()

if 'successful_template' not in st.session_state:
    st.session_state['successful_template'] = ''

# Make sure ./tmp and this session's scratch directory exist on every rerun,
# not only when the session starts: if ./tmp is cleaned up mid-session, later
# writes into ./tmp/<uuid>/ would otherwise fail. makedirs creates ./tmp too.
os.makedirs(f"./tmp/{st.session_state['uuid']}", exist_ok=True)
|
|
|
# Page header: title, short description and the list of bundled templates.
title_description = "Chat Template Generation: Make Chat Easier with Huggingface Tokenizer"

st.title(title_description)
st.markdown('This streamlit app is to serve as an easier way to check and push the chat template to your/existing huggingface repo')

list_of_templates = get_existing_templates()
with st.expander("Current predefined templates"):
    # Skip the leading None entry, which stands for "no predefined template".
    for model in list_of_templates[1:]:
        st.markdown(f"- {model}")
    st.info('More templates will be predefined for easier setup of chat template.', icon="ℹ️")

st.divider()
|
|
|
|
|
# Repository whose tokenizer is loaded, previewed and (optionally) updated.
hf_model_repo_name = st.text_input("Hugging Face Model Repository To Update", value="tiiuae/falcon-7b", max_chars=None, key=None, type="default",
                                   help=None, autocomplete=None, label_visibility="visible")

gen_button = st.button("Get Tokenizer Config")

if gen_button:
    # Ignore surrounding whitespace, which easily sneaks in when pasting ids.
    repo_id = hf_model_repo_name.strip()
    # "org/name" -> "org_name": one path component, so the save directory is
    # not split into nested folders by the slash in the repository id.
    repo_normalized_name = repo_id.replace("/", "_")
    tokenizer_dir = f"./tmp/{st.session_state['uuid']}_{repo_normalized_name}"
    try:
        with st.spinner(text="In progress...", cache=False):
            tokenizer = AutoTokenizer.from_pretrained(repo_id)
            tokenizer.save_pretrained(tokenizer_dir)
    except Exception as e:  # UI boundary: report the problem instead of crashing the page.
        st.error(f"Could not load the tokenizer for '{repo_id}': {e}")
    else:
        # Commit the new state only once loading succeeded, so a failed
        # attempt leaves the previously loaded tokenizer/repo consistent.
        st.session_state['repo_id'] = repo_id
        st.session_state['tokenizer'] = tokenizer
        st.session_state['repo_normalized_name'] = repo_normalized_name
        st.session_state['tokenizer_json'] = tokenizer_dir
        # A template validated against a previous tokenizer must be
        # re-validated before it can be pushed to this repository.
        st.session_state['successful_template'] = ''
|
|
|
def _sanitize_template_via_tmp_file(template):
    """Round-trip *template* through this session's scratch file and collapse it.

    Writing and re-reading in text mode normalises line endings
    (``\\r\\n`` / ``\\r`` -> ``\\n``) before :func:`sanitize_jinja2` joins the
    lines into one. UTF-8 is used explicitly so non-ASCII templates work
    regardless of the server's locale encoding.
    """
    tmp_path = f"./tmp/{st.session_state['uuid']}/tmp_chat_template.json"
    with open(tmp_path, "w", encoding="utf-8") as fp:
        fp.write(template)
    with open(tmp_path, "r", encoding="utf-8") as f:
        return sanitize_jinja2(f.readlines())


# Everything below needs a loaded tokenizer (set by "Get Tokenizer Config").
if st.session_state['tokenizer_json'] is not None:
    with open(f"{st.session_state['tokenizer_json']}/tokenizer_config.json", "rb") as f:
        tokenizer_json = json.load(f)

    json_spec, col2 = st.columns(spec=[0.3, 0.7])

    with json_spec:
        st.markdown(f"### Tokenizer Config from {st.session_state['repo_normalized_name']}")
        st.json(json.dumps(tokenizer_json, indent=4))

    with col2:
        # Fixed example conversation used to preview the chat template.
        chat = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello, how are you?"},
            {"role": "assistant", "content": "I'm doing great. How can I help you today?"},
            {"role": "user", "content": "I'd like to show off how chat templating works!"},
        ]
        st.markdown("### Example Conversation")
        st.json(json.dumps(chat, indent=4))

    prompt_template_col, prompt_template_output_col = st.columns(spec=[0.3, 0.7])

    with prompt_template_col:
        list_of_templates = get_existing_templates()
        selected_template = st.selectbox("Choose Existing Template or Leave Blank. (If template is None, it will check current tokenizer's `chat_template` and `default_chat_template` fields)",
                                         options=list_of_templates,
                                         index=0, placeholder="Choose a template (If template is None, it will check current tokenizer `chat_template` and `default_chat_template` fields)", disabled=False, label_visibility="visible")

        generate_prompt_example_button = st.button("Generate Prompt", key="generate_prompt_example_button")

        if selected_template is not None:
            # Predefined template file shipped in ./templates.
            with open(f"./templates/{selected_template}", "r", encoding="utf-8") as f:
                st.session_state['input_jinja_template'] = f.read()
        else:
            tokenizer = st.session_state['tokenizer']
            template = tokenizer.chat_template
            if template is None:
                # `default_chat_template` is missing on newer transformers
                # releases, so look it up defensively instead of raising.
                template = getattr(tokenizer, "default_chat_template", None)
            # The text area and the file writes below need a str, not None.
            st.session_state['input_jinja_template'] = template or ""

        edited_template = st.text_area(
            "Jinja Chat Template", value=st.session_state['input_jinja_template'],
            height=500, placeholder=None, disabled=False, label_visibility="visible")
        # Defensive: keep the stored template a str even if the widget yields None.
        st.session_state['input_jinja_template'] = edited_template or ""

    with prompt_template_output_col:
        if generate_prompt_example_button:
            tokenizer = st.session_state['tokenizer']
            try:
                tokenizer.chat_template = _sanitize_template_via_tmp_file(st.session_state['input_jinja_template'])
                generated_prompt_wo_add_generation_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=False)
                generated_prompt_w_add_generation_prompt = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
            except Exception as e:  # Jinja syntax errors, templates that raise, ...
                st.error(f"Failed to render the example conversation with this template: {e}")
            else:
                st.text_area(
                    "Generate Prompt with `add_generation_prompt=False`", value=generated_prompt_wo_add_generation_prompt,
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_wo_add_generation_prompt")

                st.text_area(
                    "Generate Prompt with `add_generation_prompt=True`", value=generated_prompt_w_add_generation_prompt,
                    height=300, placeholder=None, disabled=True, label_visibility="visible", key="generated_prompt_w_add_generation_prompt")

                # Only a template that rendered without error becomes pushable.
                # (str is immutable, so no copy is needed.)
                st.session_state['successful_template'] = st.session_state['input_jinja_template']

        if len(st.session_state['successful_template']) > 0:
            access_token_no_cache = st.text_input("HuggingFace Access Token API with Write Access", type="password", key="access_token_no_cache")
            commit_message_text_input = st.text_input("Commit Message", key="commit_message_text_input")
            to_private_checkbox = st.checkbox("To Private Repo", key="to_private_checkbox")
            # Options are rendered before the button that uses them.
            create_pr_checkbox = st.checkbox("Create PR (For Contribution 🤗)", key="create_pr_checkbox")
            push_to_hub_button = st.button("Push to Hub", key="push_to_hub_button")

            if push_to_hub_button:
                if not access_token_no_cache:
                    st.error("Please provide a Hugging Face access token with write access.")
                else:
                    # Push the last template that rendered successfully, not
                    # whatever is currently in the editor.
                    st.session_state['tokenizer'].chat_template = _sanitize_template_via_tmp_file(st.session_state['successful_template'])
                    try:
                        with st.spinner(text="Pushing to hub ...", cache=False):
                            st.session_state['tokenizer'].push_to_hub(
                                repo_id=st.session_state['repo_id'],
                                # An empty message is rejected by the Hub client;
                                # None lets the library use its default message.
                                commit_message=commit_message_text_input or None,
                                private=to_private_checkbox,
                                token=access_token_no_cache,
                                create_pr=create_pr_checkbox)
                    except Exception as e:  # UI boundary: show the error to the user.
                        st.write(f"Repo id: {st.session_state['repo_id']}")
                        st.error(str(e))
                    else:
                        st.success(f"Pushed chat template to {st.session_state['repo_id']}")
|
|