# Uploaded by awinml ("Upload 3 files", commit bf8b612; 4.88 kB).
import pinecone
import streamlit as st
st.set_page_config(layout="wide")
import streamlit_scrollable_textbox as stx
import openai
from utils import (
get_data,
get_mpnet_embedding_model,
get_sgpt_embedding_model,
get_flan_t5_model,
get_t5_model,
save_key,
)
from utils import (
retrieve_transcript,
query_pinecone,
format_query,
sentence_id_combine,
text_lookup,
generate_prompt,
gpt_model,
)
# Page header: the app title followed by a short description of the corpus
# (which companies and years of earnings-call transcripts are searchable).
_APP_DESCRIPTION = (
    "The app uses the quarterly earnings call transcripts for 10 companies "
    "(Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) "
    "for the years 2016 to 2020."
)
st.title("Abstractive Question Answering")
st.write(_APP_DESCRIPTION)
# Two equal-width columns: col1 holds the question and filters (and later the
# retrieved sources), col2 is used for the model prompt / answer.
col1, col2 = st.columns([3, 3], gap="medium")

# Selectable filters: years newest-first (2020 down to 2016), quarters, tickers.
years_choice = [str(yr) for yr in range(2020, 2015, -1)]
ticker_choice = [
    "AAPL", "CSCO", "MSFT", "ASML", "NVDA",
    "GOOGL", "MU", "INTC", "AMZN", "AMD",
]

with col1:
    st.subheader("Question")
    query_text = st.text_input(
        "Input Query",
        value="What was discussed regarding Wearables revenue performance?",
    )
    year = st.selectbox("Year", years_choice)
    quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4"])
    ticker = st.selectbox("Company", ticker_choice)
# Sidebar options: how many matches to request from the index, and which
# encoder (retrieval embeddings) and decoder (answer generation) to use.
encoder_models_choice = ["MPNET", "SGPT"]
decoder_models_choice = [
    "GPT3 - (text-davinci-003)",
    "T5",
    "FLAN-T5",
]

with st.sidebar:
    st.subheader("Select Options:")
    num_results = int(
        st.number_input(
            "Number of Results to query", min_value=1, max_value=15, value=6
        )
    )
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
# Per-encoder retrieval setup: (Streamlit secret holding the Pinecone API key,
# Pinecone index name, loader for the matching embedding model). The two
# encoders use different keys and indexes; presumably each index holds
# embeddings from its own encoder, so the three values must stay paired.
_ENCODER_CONFIG = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

if encoder_model in _ENCODER_CONFIG:
    secret_name, pinecone_index_name, load_retriever = _ENCODER_CONFIG[encoder_model]
    # Connect to pinecone environment
    pinecone.init(api_key=st.secrets[secret_name], environment="us-east1-gcp")
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = load_retriever()
# Sidebar retrieval tuning. `window` is later passed to sentence_id_combine as
# `lag`; `threshold` is passed to query_pinecone and also selects how the
# retrieved results are turned into context passages.
with st.sidebar:
    window = int(st.number_input("Sentence Window Size", 0, 10, value=1))
    threshold = float(
        st.number_input(
            label="Similarity Score Threshold",
            value=0.25,
            step=0.05,
            format="%.2f",
        )
    )
# Query the vector index for the selected company/year/quarter and turn the
# matches into context passages plus a prompt for the decoder.
data = get_data()
query_results = query_pinecone(
    query_text, num_results, retriever_model, pinecone_index,
    year, quarter, ticker, threshold,
)
# Thresholds up to 0.90 combine each match with neighbouring sentences
# (presumably `window` on each side); above 0.90 the raw matches are only
# formatted. NOTE(review): the window setting therefore has no effect when
# the threshold is above 0.90.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.90
    else format_query(query_results)
)
prompt = generate_prompt(query_text, context_list)
def _summarize_contexts(summarizer, contexts):
    """Return a two-stage summary of ``contexts`` produced by ``summarizer``.

    Each context string is summarized on its own. The partial summaries are
    joined with ". ", and that joined text is summarized once more.
    ``summarizer`` is the object returned by get_t5_model() /
    get_flan_t5_model(); it is called with a single string and its output is
    read as ``output[0]["summary_text"]``.

    Returns None when ``contexts`` is empty, so the caller can report that
    nothing was retrieved instead of asking the model to summarize "".
    """
    if not contexts:
        return None
    partial_summaries = [summarizer(text)[0]["summary_text"] for text in contexts]
    return summarizer(". ".join(partial_summaries))[0]["summary_text"]


# Generate an answer with the decoder selected in the sidebar and show it in
# the right-hand column.
if decoder_model == "GPT3 - (text-davinci-003)":
    # The prompt is shown in an editable text area; gpt_model is only called
    # after the form is submitted together with an OpenAI key.
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(label="Model Prompt", value=prompt, height=270)
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")
    if submitted:
        # Drop stray whitespace from pasted keys, and refuse an empty key
        # rather than handing it to save_key / gpt_model.
        openai_key = openai_key.strip()
        if not openai_key:
            with col2:
                st.warning("Please enter an OpenAI key before submitting.")
        else:
            api_key = save_key(openai_key)
            openai.api_key = api_key
            generated_text = gpt_model(edited_prompt)
            with col2:
                st.subheader("Answer:")
                st.write(generated_text)
elif decoder_model in ("T5", "FLAN-T5"):
    # NOTE(review): unlike the GPT-3 path, this path never uses query_text or
    # prompt -- it only summarizes the retrieved passages. Confirm intended.
    summarizer = get_t5_model() if decoder_model == "T5" else get_flan_t5_model()
    answer = _summarize_contexts(summarizer, context_list)
    with col2:
        st.subheader("Answer:")
        if answer is None:
            st.warning(
                "No passages were retrieved for this query, so there is nothing to summarize."
            )
        else:
            st.write(answer)
# Below the inputs in the left column: the passages used as context, then the
# full transcript for the selected company/year/quarter, each in an expander.
with col1, st.expander("See Retrieved Text"):
    for passage in context_list:
        st.markdown(f"- {passage}")

file_text = retrieve_transcript(data, year, quarter, ticker)
with col1, st.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )