#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 Statistics and Machine Learning Research Group at HKUST. All rights reserved.
"""A simple Gradio web chatbot implemented with LMFlow APIs."""
import logging
import json
import os
import sys

# Remove the script's own directory from sys.path so that `lmflow` resolves to
# the installed package rather than a local module with the same name.
sys.path.remove(os.path.abspath(os.path.dirname(sys.argv[0])))
import torch
import warnings
import gradio as gr
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from typing import Optional

from lmflow.datasets.dataset import Dataset
from lmflow.pipeline.auto_pipeline import AutoPipeline
from lmflow.models.auto_model import AutoModel
from lmflow.args import ModelArguments, DatasetArguments, AutoArguments

MAX_BOXES = 20

logging.disable(logging.ERROR)
warnings.filterwarnings("ignore")
title = """ | |
<h1 align="center">LMFlow-CHAT</h1> | |
<link rel="stylesheet" href="/path/to/styles/default.min.css"> | |
<script src="/path/to/highlight.min.js"></script> | |
<script>hljs.highlightAll();</script> | |
<img src="https://optimalscale.github.io/LMFlow/_static/logo.png" alt="LMFlow" style="width: 30%; min-width: 60px; display: block; margin: auto; background-color: transparent;"> | |
<p>LMFlow is in extensible, convenient, and efficient toolbox for finetuning large machine learning models, designed to be user-friendly, speedy and reliable, and accessible to the entire community.</p> | |
<p>We have thoroughly tested this toolkit and are pleased to make it available under <a class="reference external" href="https://github.com/OptimalScale/LMFlow">Github</a>.</p> | |
""" | |
css = """ | |
#user { | |
float: right; | |
position:relative; | |
right:5px; | |
width:auto; | |
min-height:32px; | |
max-width: 60% | |
line-height: 32px; | |
padding: 2px 8px; | |
font-size: 14px; | |
background: #9DC284; | |
border-radius:5px; | |
margin:10px 0px; | |
} | |
#chatbot { | |
float: left; | |
position:relative; | |
right:5px; | |
width:auto; | |
min-height:32px; | |
max-width: 60% | |
line-height: 32px; | |
padding: 2px 8px; | |
font-size: 14px; | |
background:#7BA7D7; | |
border-radius:5px; | |
margin:10px 0px; | |
} | |
""" | |
@dataclass
class ChatbotArguments:
    prompt_structure: Optional[str] = field(
        default="###Human: {input_text}###Assistant:",
        metadata={
            "help": "prompt structure given user's input text"
        },
    )
    end_string: Optional[str] = field(
        default="#",
        metadata={
            "help": "end string mark of the chatbot's output"
        },
    )
    max_new_tokens: Optional[int] = field(
        default=1500,
        metadata={
            "help": "maximum number of generated tokens"
        },
    )
    temperature: Optional[float] = field(
        default=0.7,
        metadata={
            "help": "the higher this value, the more random the model output"
        },
    )
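
# Example invocation (illustrative only; flag names follow the dataclass fields
# above and LMFlow's standard ModelArguments, and the model path is overridden
# inside main() below, so only the chatbot flags take effect):
#   python this_script.py --max_new_tokens 512 --temperature 0.7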
def main():
    pipeline_name = "inferencer"
    PipelineArguments = AutoArguments.get_pipeline_args_class(pipeline_name)

    parser = HfArgumentParser((
        ModelArguments,
        PipelineArguments,
        ChatbotArguments,
    ))
    model_args, pipeline_args, chatbot_args = (
        parser.parse_args_into_dataclasses()
    )

    # Hard-coded overrides: always serve this model in float16 with the
    # bundled DeepSpeed chatbot config, regardless of command-line values.
    model_args.model_name_or_path = "LMFlow/Full-Robin-13b-v2"
    pipeline_args.deepspeed = "configs/ds_config_chatbot.json"
    model_args.torch_dtype = "float16"

    with open(pipeline_args.deepspeed, "r") as f:
        ds_config = json.load(f)
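
    # A minimal DeepSpeed inference config might look like the following
    # (illustrative sketch only; use the actual configs/ds_config_chatbot.json
    # that ships with LMFlow):
    #   {
    #       "fp16": {"enabled": true},
    #       "train_micro_batch_size_per_gpu": 1
    #   }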
    # Load the model for inference only (tune_strategy='none' skips any
    # finetuning setup), using the DeepSpeed config loaded above.
    model = AutoModel.get_model(
        model_args,
        tune_strategy='none',
        ds_config=ds_config,
        device=pipeline_args.device,
        torch_dtype=torch.float16,
    )
    # We don't need input data; queries arrive interactively from the web UI.
    data_args = DatasetArguments(dataset_path=None)
    dataset = Dataset(data_args)

    inferencer = AutoPipeline.get_pipeline(
        pipeline_name=pipeline_name,
        model_args=model_args,
        data_args=data_args,
        pipeline_args=pipeline_args,
    )
    # Chats
    model_name = model_args.model_name_or_path
    if model_args.lora_model_path is not None:
        model_name += f" + {model_args.lora_model_path}"

    # context = (
    #     "You are a helpful assistant who follows the given instructions"
    #     " unconditionally."
    # )

    end_string = chatbot_args.end_string
    prompt_structure = chatbot_args.prompt_structure
    # Number of tokens to generate per streaming step before yielding output.
    token_per_step = 4
    def hist2context(hist):
        """Flatten a list of (query, response) pairs into a single prompt."""
        context = ""
        for query, response in hist:
            context += prompt_structure.format(input_text=query)
            if response is not None:
                context += response
        return context
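
    # For example, with the default prompt_structure, a history of
    # [("Hi", "Hello!")] flattens to:
    #   "###Human: Hi###Assistant:Hello!"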
    def chat_stream(query: str, history=None, **kwargs):
        if history is None:
            history = []

        context = hist2context(history)
        print_index = 0
        context += prompt_structure.format(input_text=query)
        # Keep only the tail of the context that fits the model's max length.
        # (Note: this truncates by characters, not tokens.)
        context_ = context[-model.get_max_length():]
        input_dataset = dataset.from_dict({
            "type": "text_only",
            "instances": [{"text": context_}],
        })
        print(context_)

        for response, flag_break in inferencer.stream_inference(
            context=context_,
            model=model,
            max_new_tokens=chatbot_args.max_new_tokens,
            token_per_step=token_per_step,
            temperature=chatbot_args.temperature,
            end_string=end_string,
            input_dataset=input_dataset,
        ):
            # Yield only the newly generated suffix since the last step.
            delta = response[print_index:]
            seq = response
            print_index = len(response)

            yield delta, history + [(query, seq)]
            if flag_break:
                break
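
    # chat_stream is a generator: each step yields (delta, updated_history),
    # where delta is the text added since the previous yield. For a query
    # "Hi" it might yield ("Hel", ...), then ("lo!", ...) as tokens stream in
    # (illustrative values only).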
    def predict(input, history=None):
        if history is None:
            history = []

        for response, history in chat_stream(input, history):
            updates = []
            for query, response in history:
                updates.append(gr.update(visible=True, value="" + query))
                updates.append(gr.update(visible=True, value="" + response))
            # Hide any chat boxes that are not yet in use.
            if len(updates) < MAX_BOXES:
                updates = updates + [gr.update(visible=False)] * (MAX_BOXES - len(updates))
            yield [history] + updates
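
    # Each yield maps positionally onto the click() outputs registered below:
    # the first element updates the `state` component, and the remaining
    # MAX_BOXES updates refresh the alternating question/answer Markdown boxes.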
    with gr.Blocks(css=css) as demo:
        gr.HTML(title)
        state = gr.State([])
        text_boxes = []
        # Pre-create MAX_BOXES hidden Markdown boxes, alternating between
        # user questions (even indices) and chatbot answers (odd indices).
        for i in range(MAX_BOXES):
            if i % 2 == 0:
                text_boxes.append(gr.Markdown(visible=False, label="Q:", elem_id="user"))
            else:
                text_boxes.append(gr.Markdown(visible=False, label="A:", elem_id="chatbot"))

        txt = gr.Textbox(
            show_label=False,
            placeholder="Enter text and press send.",
        )
        button = gr.Button("Send")

        button.click(predict, [txt, state], [state] + text_boxes)
        demo.queue().launch()
if __name__ == "__main__": | |
main() | |