# Hugging Face file-viewer header: author Teera, commit e506efc (verified), 2.39 kB.
import os
import tempfile
from io import BytesIO

import faiss
import gradio as gr
import numpy as np
import pandas as pd
from FlagEmbedding import BGEM3FlagModel
# Load the pre-trained BGE-M3 embedding model from the 'BAAI/bge-m3' checkpoint,
# with half-precision weights enabled (use_fp16=True).
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
# Load the JSON records into a DataFrame. The 'embeding_context' column (spelled
# this way in the data) is presumably the text that was embedded into the FAISS
# index loaded below — TODO confirm against the index-building script.
df = pd.read_json('White-Stride-Red-68.json')
# Coerce the context column to strings.
# NOTE(review): astype(str) runs before fillna(''). On pandas versions where
# astype(str) yields object-dtype strings, missing values become the literal
# text 'nan'/'None', so fillna('') has nothing left to fill. Those rows then
# survive the empty-string filter below.
df['embeding_context'] = df['embeding_context'].astype(str).fillna('')
# Keep only rows whose 'embeding_context' is not the empty string.
# NOTE(review): search_query maps FAISS result ids to rows with df.iloc, i.e. by
# position in this filtered frame. The rows kept here must therefore line up
# one-to-one with the vectors in the index. Do not change this cleaning (e.g.
# swap the astype/fillna order) without rebuilding the index the same way.
df = df[df['embeding_context'] != '']
# Load the prebuilt FAISS index from disk. It is assumed to hold one vector per
# row of the filtered df, in the same order — TODO confirm.
index = faiss.read_index('vector_store_bge_m3.index')
# Function to perform search and return all columns
def search_query(query_text, num_records=50):
    """Return the dataset rows most similar to ``query_text``.

    The query is embedded with the module-level BGE-M3 ``model`` and looked up
    in the module-level FAISS ``index``. Each hit id is mapped back to a row of
    ``df`` by position.

    Args:
        query_text: Free-text search query.
        num_records: Number of nearest neighbours to request from FAISS.
            Defaults to 50, the value previously hard-coded here.

    Returns:
        A DataFrame with every column of ``df`` except ``embeding_context``.
        Rows are in FAISS result order, exact duplicate rows are removed, and
        the index is reset to 0..n-1. The frame may hold fewer than
        ``num_records`` rows.
    """
    # encode() works on a batch, so wrap the single query in a list and keep
    # only the dense vectors. FAISS expects float32 input.
    encoded = model.encode([query_text], batch_size=12, max_length=1024)['dense_vecs']
    query_vecs = np.asarray(encoded, dtype='float32')

    # Search the FAISS index for the nearest neighbours of the query.
    _distances, indices = index.search(query_vecs, num_records)

    # FAISS pads results with -1 when it finds fewer than num_records
    # neighbours (e.g. the index holds fewer vectors). Without this filter,
    # df.iloc[-1] would silently add the last row of df as a bogus hit.
    hit_positions = [int(pos) for pos in indices[0] if pos >= 0]

    return (
        df.iloc[hit_positions]
        .drop(columns=['embeding_context'])
        .drop_duplicates()
        .reset_index(drop=True)
    )
# Gradio interface function
def gradio_interface(query_text):
    """Run a search and prepare its results for the Gradio UI.

    Args:
        query_text: The text entered in the search box.

    Returns:
        A ``(results, download_update)`` tuple:
        - ``results``: the search results DataFrame, for the results table.
        - ``download_update``: a ``gr.update`` that points the download button
          at an ``.xlsx`` export of the same rows.
    """
    search_results = search_query(query_text)

    # gr.DownloadButton's value must be a file path (or URL), not raw bytes,
    # so the workbook is written to disk. A fresh temp directory per request
    # keeps concurrent users from overwriting each other's export. It also lets
    # every export be named search_results.xlsx for the browser download.
    # These directories are not removed here; they are left for OS temp cleanup.
    export_dir = tempfile.mkdtemp(prefix='search_export_')
    excel_path = os.path.join(export_dir, 'search_results.xlsx')
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as writer:
        search_results.to_excel(writer, index=False)

    # Return the DataFrame and point the download button at the exported file.
    return search_results, gr.update(value=excel_path)
with gr.Blocks() as app:
    # Page title; BGE-M3 is the BAAI/bge-m3 embedding model used for search.
    gr.Markdown("<h1>White Stride Red Search (BGE-M3)</h1>")
    # Input text box for the search query
    search_input = gr.Textbox(label="Search Query", placeholder="Enter search text", interactive=True)
    # Search button below the text box
    search_button = gr.Button("Search")
    # Output table for displaying results
    search_output = gr.DataFrame(label="Search Results")
    # Download button for the Excel export. It has no file until the first
    # search; gradio_interface then sets its value to the exported file.
    download_button = gr.DownloadButton(label="Download Excel")
    # Run the search from either the button or pressing Enter in the text box.
    search_button.click(fn=gradio_interface, inputs=search_input, outputs=[search_output, download_button])
    search_input.submit(fn=gradio_interface, inputs=search_input, outputs=[search_output, download_button])

# Launch the Gradio app only when run as a script (e.g. `python app.py`).
if __name__ == "__main__":
    app.launch()