|
import streamlit as st |
|
from streamlit_chat import message |
|
from ingest_data import embed_doc |
|
from query_data import get_chain |
|
import os |
|
import time |
|
|
|
# Browser tab title/icon; must be the first Streamlit call executed in the app.
st.set_page_config(page_title="LangChain Local PDF Chat", page_icon=":robot:")

# Raw HTML/CSS for an attribution footer pinned to the bottom-right of the page.
# (\U0001F916 is the Python escape for the robot emoji; blank lines are part of
# the original string literal and are preserved as-is.)
footer="""<style>



.footer {

position: fixed;

left: 0;

bottom: 0;

width: 100%;

background-color: white;

color: black;

text-align: right;

}

</style>

<div class="footer">

<p>Adapted with ❤ and \U0001F916 by Fakezeta from the original Mobilefirst</p>

</div>

"""
# unsafe_allow_html is required so Streamlit injects the raw HTML instead of escaping it.
st.markdown(footer,unsafe_allow_html=True)
|
|
|
def process_file(uploaded_file):
    """Persist the uploaded PDF to disk, vectorize it, and clean up.

    ``embed_doc`` expects a file path, so the uploaded buffer is first
    written to the working directory under its original name. The temp
    copy is always removed afterwards, even if embedding fails.

    Args:
        uploaded_file: a Streamlit UploadedFile (has .name and .getbuffer()).

    Returns:
        The vectorstore produced by embed_doc().
    """
    # The `with` statement closes the file on exit — the original code's
    # extra f.close() afterwards was redundant and has been removed.
    with open(uploaded_file.name, "wb") as f:
        f.write(uploaded_file.getbuffer())
    st.write("File Uploaded successfully")

    try:
        with st.spinner("Document is being vectorized...."):
            vectorstore = embed_doc(uploaded_file.name)
    finally:
        # Guarantee the temporary on-disk copy is deleted even when
        # embed_doc raises (the original leaked the file on failure).
        os.remove(uploaded_file.name)
    return vectorstore
|
|
|
def get_text():
    """Render the user input box and return its current text.

    Returns:
        str: the contents of the text input widget.
    """
    # Bug fix: the original read st.session_state.disabled via attribute
    # access, but no code in this file ever initializes "disabled", so
    # calling this function would raise. .get() with a False default is safe.
    # NOTE(review): this widget shares key="input" with the text_input in the
    # main flow — rendering both in the same run would collide; verify callers.
    input_text = st.text_input(
        "You: ",
        value="",
        key="input",
        disabled=st.session_state.get("disabled", False),
    )
    return input_text
|
|
|
def query(query):
    """Run the conversational chain on *query* and return its answer.

    The most recent (human, assistant) exchange, if any, is passed as a
    single-turn chat_history so the chain has conversational context.

    Args:
        query: the user's question (str).

    Returns:
        The chain's output for this question.
    """
    start = time.time()
    with st.spinner("Doing magic...."):
        # Truthiness instead of len(...) > 0; the HUMAN:/ASSISTANT: labels
        # are part of the history format the chain expects — do not change.
        if st.session_state.past and st.session_state.generated:
            chat_history = [(
                "HUMAN: " + st.session_state.past[-1],
                "ASSISTANT: " + st.session_state.generated[-1],
            )]
        else:
            chat_history = []
        print("chat_history:", chat_history)
        output = st.session_state.chain.run(
            input=query,
            question=query,
            vectorstore=st.session_state.vectorstore,
            chat_history=chat_history,
        )
    end = time.time()
    # Bug fix: the original message contained a stray "\a" (ASCII bell
    # character) that beeped the terminal; removed.
    print("Query time: " + str(round(end - start, 1)))
    return output
|
|
|
|
|
# Inject the project stylesheet from disk (raw HTML needs unsafe_allow_html).
with open("style.css") as f:
    st.markdown('<style>{}</style>'.format(f.read()), unsafe_allow_html=True)

st.header("Local Chat with Pdf")

# --- One-time per-session state initialization (survives Streamlit reruns) ---

# Name of the last processed upload; "" means no document loaded yet.
if "uploaded_file_name" not in st.session_state:
    st.session_state.uploaded_file_name = ""

# History of user questions, parallel to "generated" below.
if "past" not in st.session_state:
    st.session_state.past = []

# History of model answers.
if "generated" not in st.session_state:
    st.session_state["generated"] = []

# Vector index built from the uploaded PDF (None until a file is processed).
if "vectorstore" not in st.session_state:
    st.session_state.vectorstore = None

# Conversational chain built on top of the vectorstore.
if "chain" not in st.session_state:
    st.session_state.chain = None

uploaded_file = st.file_uploader("Choose a file", type=['pdf'])

if uploaded_file:
    # A different file name invalidates everything derived from the previous
    # document: index, chain, and the whole chat history.
    if uploaded_file.name != st.session_state.uploaded_file_name:
        st.session_state.vectorstore = None
        st.session_state.chain = None
        st.session_state["generated"] = []
        st.session_state.past = []
        st.session_state.uploaded_file_name = uploaded_file.name
        # NOTE(review): all_messages is written here but never read anywhere
        # in this file — looks like dead state; confirm before removing.
        st.session_state.all_messages = []
        print(st.session_state.uploaded_file_name)
    # Only (re)vectorize when no index exists for the current file.
    if not st.session_state.vectorstore:
        st.session_state.vectorstore = process_file(uploaded_file)

# Build the LLM chain once per document, after the index is ready.
if st.session_state.vectorstore and not st.session_state.chain:
    with st.spinner("Loading Large Language Model...."):
        st.session_state.chain=get_chain(st.session_state.vectorstore)

# NOTE(review): `searching` is a plain local recreated on every rerun, and
# Streamlit reruns the whole script per interaction — toggling it below can
# never actually disable this widget; it is always False when text_input renders.
searching=False
user_input = st.text_input("You: ", value="", key="input", disabled=searching)
send_button = st.button(label="Query")

if send_button:
    searching = True
    output = query(user_input)
    searching = False
    # Append the new exchange so it is rendered below and used as context next turn.
    st.session_state.past.append(user_input)
    st.session_state.generated.append(output)

# Render the conversation newest-first with streamlit_chat's message();
# each widget needs a unique key, hence the "_user" suffix for user turns.
if st.session_state["generated"]:
    for i in range(len(st.session_state["generated"]) - 1, -1, -1):
        message(st.session_state["generated"][i], key=str(i))
        message(st.session_state.past[i], is_user=True, key=str(i) + "_user")
|
|