|
import pinecone |
|
import streamlit as st |
|
|
|
st.set_page_config(layout="wide") |
|
|
|
import streamlit_scrollable_textbox as stx |
|
import openai |
|
from utils import ( |
|
get_data, |
|
get_mpnet_embedding_model, |
|
get_sgpt_embedding_model, |
|
get_flan_t5_model, |
|
get_t5_model, |
|
save_key, |
|
) |
|
|
|
from utils import ( |
|
retrieve_transcript, |
|
query_pinecone, |
|
format_query, |
|
sentence_id_combine, |
|
text_lookup, |
|
generate_prompt, |
|
gpt_model, |
|
) |
|
|
|
|
|
# Page header: the app title followed by a one-line description of the corpus.
APP_TITLE = "Abstractive Question Answering"
APP_DESCRIPTION = "The app uses the quarterly earnings call transcripts for 10 companies (Apple, AMD, Amazon, Cisco, Google, Microsoft, Nvidia, ASML, Intel, Micron) for the years 2016 to 2020."

st.title(APP_TITLE)
st.write(APP_DESCRIPTION)
|
|
|
# Two equal-weight columns. The left one (col1) holds the query inputs; the
# right one (col2) is filled in further down the script.
col1, col2 = st.columns([3, 3], gap="medium")

# Choices offered by the "Year" and "Company" select boxes below.
years_choice = ["2020", "2019", "2018", "2017", "2016", "All"]
ticker_choice = [
    "AAPL",
    "CSCO",
    "MSFT",
    "ASML",
    "NVDA",
    "GOOGL",
    "MU",
    "INTC",
    "AMZN",
    "AMD",
]

# All query inputs live in the left column, created in display order.
with col1:
    st.subheader("Question")
    query_text = st.text_input(
        "Input Query",
        value="What was discussed regarding Wearables revenue performance?",
    )
    year = st.selectbox("Year", years_choice)
    quarter = st.selectbox("Quarter", ["Q1", "Q2", "Q3", "Q4", "All"])
    participant_type = st.selectbox("Speaker", ["Company Speaker", "Analyst"])
    ticker = st.selectbox("Company", ticker_choice)
|
|
|
# Names offered by the encoder / decoder select boxes in the sidebar.
encoder_models_choice = ["MPNET", "SGPT"]
decoder_models_choice = [
    "GPT3 - (text-davinci-003)",
    "T5",
    "FLAN-T5",
]

# Sidebar: result count plus the encoder and decoder model pickers.
with st.sidebar:
    st.subheader("Select Options:")
    # Bounded to 1..15 with a default of 6; cast to int for downstream use.
    results_input = st.number_input(
        "Number of Results to query", min_value=1, max_value=15, value=6
    )
    num_results = int(results_input)
    encoder_model = st.selectbox("Select Encoder Model", encoder_models_choice)
    decoder_model = st.selectbox("Select Decoder Model", decoder_models_choice)
|
|
|
|
|
# Per-encoder settings, keyed by the value of the encoder select box:
# (name of the Streamlit secret holding the Pinecone API key,
#  Pinecone index name, loader for the matching embedding model).
encoder_settings = {
    "MPNET": ("pinecone_mpnet", "week2-all-mpnet-base", get_mpnet_embedding_model),
    "SGPT": ("pinecone_sgpt", "week2-sgpt-125m", get_sgpt_embedding_model),
}

# Only known encoders are configured; any other value leaves the names below
# unset, exactly as an unmatched if/elif chain would.
if encoder_model in encoder_settings:
    secret_name, pinecone_index_name, load_retriever = encoder_settings[encoder_model]
    pinecone.init(api_key=st.secrets[secret_name], environment="us-east1-gcp")
    pinecone_index = pinecone.Index(pinecone_index_name)
    retriever_model = load_retriever()
|
|
|
|
|
# Sidebar: retrieval tuning knobs.
with st.sidebar:
    # Integer in 0..10, default 1. Passed later as `lag` to
    # sentence_id_combine.
    window_input = st.number_input("Sentence Window Size", 0, 10, value=1)
    window = int(window_input)

    # Float shown with two decimals, adjusted in 0.05 steps, default 0.25.
    threshold_input = st.number_input(
        label="Similarity Score Threshold",
        step=0.05,
        format="%.2f",
        value=0.25,
    )
    threshold = float(threshold_input)
|
|
|
# Load the dataset used for context assembly and the transcript lookup.
data = get_data()

# Query the selected Pinecone index with the user's question and filters.
query_results = query_pinecone(
    query_text,
    num_results,
    retriever_model,
    pinecone_index,
    year,
    quarter,
    ticker,
    participant_type,
    threshold,
)

# For thresholds up to 0.90 the hits are combined with `data` using the
# sentence window as `lag`; above that the hits are only formatted.
# NOTE(review): the rationale for the 0.90 cutoff is not visible here.
context_list = (
    sentence_id_combine(data, query_results, lag=window)
    if threshold <= 0.90
    else format_query(query_results)
)

# Build the prompt from the question and the retrieved context.
prompt = generate_prompt(query_text, context_list)
|
|
|
def _two_stage_summary(summarizer, contexts):
    """Summarize every context, then summarize the joined partial summaries.

    Args:
        summarizer: Callable taking a string and returning a sequence whose
            first element is a mapping with a ``"summary_text"`` key.
        contexts: Iterable of context strings to condense.

    Returns:
        The ``"summary_text"`` produced by running ``summarizer`` on the
        partial summaries joined with ``". "``.
    """
    partial_summaries = [summarizer(text)[0]["summary_text"] for text in contexts]
    return summarizer(". ".join(partial_summaries))[0]["summary_text"]


# Loaders for the decoders that share the two-stage summarization flow.
summarizer_loaders = {
    "T5": get_t5_model,
    "FLAN-T5": get_flan_t5_model,
}

if decoder_model == "GPT3 - (text-davinci-003)":
    # The prompt stays editable and is only sent once the form is submitted.
    with col2:
        with st.form("my_form"):
            edited_prompt = st.text_area(label="Model Prompt", value=prompt, height=270)
            openai_key = st.text_input(
                "Enter OpenAI key",
                value="",
                type="password",
            )
            submitted = st.form_submit_button("Submit")

    if submitted:
        # The entered key goes through save_key; whatever it returns is
        # installed as the OpenAI API key before the model call.
        api_key = save_key(openai_key)
        openai.api_key = api_key
        generated_text = gpt_model(edited_prompt)
        with col2:
            st.subheader("Answer:")
            st.write(generated_text)

elif decoder_model in summarizer_loaders:
    # len() rather than truthiness: works for any sized container the
    # retrieval helpers might return.
    if len(context_list) == 0:
        # Nothing was retrieved, so there is nothing to summarize; avoid
        # running the model on an empty string and presenting that as an answer.
        with col2:
            st.subheader("Answer:")
            st.warning(
                "No passages were retrieved for this query. "
                "Try lowering the similarity score threshold or changing the filters."
            )
    else:
        summarizer = summarizer_loaders[decoder_model]()
        generated_text = _two_stage_summary(summarizer, context_list)
        with col2:
            st.subheader("Answer:")
            st.write(generated_text)
|
|
|
# Left column: collapsible bullet list of the retrieved context passages.
with col1, st.expander("See Retrieved Text"):
    for passage in context_list:
        st.markdown(f"- {passage}")

# Look up the transcript for the selected year, quarter and company.
file_text = retrieve_transcript(data, year, quarter, ticker)

# Left column: collapsible, scrollable view of that transcript.
with col1, st.expander("See Transcript"):
    stx.scrollableTextbox(
        file_text,
        height=700,
        border=False,
        fontFamily="Helvetica",
    )
|
|