Teera committed on
Commit
e506efc
1 Parent(s): 3359dd0
Files changed (1) hide show
  1. app.py +15 -32
app.py CHANGED
@@ -4,6 +4,7 @@ import faiss
4
  import numpy as np
5
  import os
6
  from FlagEmbedding import BGEM3FlagModel
 
7
 
8
  # Load the pre-trained embedding model
9
  model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
@@ -15,36 +16,9 @@ df['embeding_context'] = df['embeding_context'].astype(str).fillna('')
15
  # Filter out any rows where 'embeding_context' might be empty or invalid
16
  df = df[df['embeding_context'] != '']
17
 
18
- # # Encode the 'embeding_context' column
19
- # embedding_contexts = df['embeding_context'].tolist()
20
- # embeddings_csv = model.encode(embedding_contexts, batch_size=12, max_length=1024)['dense_vecs']
21
-
22
- # # Convert embeddings to numpy array
23
- # embeddings_np = np.array(embeddings_csv).astype('float32')
24
-
25
- # # FAISS index file path
26
- # index_file_path = 'vector_store_bge_m3.index'
27
-
28
- # # Check if FAISS index file already exists
29
- # if os.path.exists(index_file_path):
30
- # # Load the existing FAISS index from file
31
- # index = faiss.read_index(index_file_path)
32
- # print("FAISS index loaded from file.")
33
- # else:
34
- # # Initialize FAISS index (for L2 similarity)
35
- # dim = embeddings_np.shape[1]
36
- # index = faiss.IndexFlatL2(dim)
37
-
38
- # # Add embeddings to the FAISS index
39
- # index.add(embeddings_np)
40
-
41
- # # Save the FAISS index to a file for future use
42
- # faiss.write_index(index, index_file_path)
43
- # print("FAISS index created and saved to file.")
44
-
45
  index = faiss.read_index('vector_store_bge_m3.index')
46
 
47
-
48
  # Function to perform search and return all columns
49
  def search_query(query_text):
50
  num_records = 50
@@ -64,7 +38,15 @@ def search_query(query_text):
64
  # Gradio interface function
65
  def gradio_interface(query_text):
66
  search_results = search_query(query_text)
67
- return search_results
 
 
 
 
 
 
 
 
68
 
69
  with gr.Blocks() as app:
70
  gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")
@@ -78,10 +60,11 @@ with gr.Blocks() as app:
78
  # Output table for displaying results
79
  search_output = gr.DataFrame(label="Search Results")
80
 
81
- # Link button click to action
82
- search_button.click(fn=gradio_interface, inputs=search_input, outputs=search_output)
83
-
84
 
 
 
85
 
86
  # Launch the Gradio app
87
  app.launch()
 
4
  import numpy as np
5
  import os
6
  from FlagEmbedding import BGEM3FlagModel
7
+ from io import BytesIO
8
 
9
  # Load the pre-trained embedding model
10
  model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
 
16
  # Filter out any rows where 'embeding_context' might be empty or invalid
17
  df = df[df['embeding_context'] != '']
18
 
19
+ # Load the FAISS index
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  index = faiss.read_index('vector_store_bge_m3.index')
21
 
 
22
  # Function to perform search and return all columns
23
  def search_query(query_text):
24
  num_records = 50
 
38
  # Gradio interface function
39
  def gradio_interface(query_text):
40
  search_results = search_query(query_text)
41
+
42
+ # Save search_results to an Excel file in memory
43
+ output = BytesIO()
44
+ with pd.ExcelWriter(output, engine='xlsxwriter') as writer:
45
+ search_results.to_excel(writer, index=False)
46
+ excel_data = output.getvalue()
47
+
48
+ # Return the DataFrame and update the download button
49
+ return search_results, gr.update(value=excel_data)
50
 
51
  with gr.Blocks() as app:
52
  gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")
 
60
  # Output table for displaying results
61
  search_output = gr.DataFrame(label="Search Results")
62
 
63
+ # Download button for Excel file
64
+ download_button = gr.DownloadButton(label="Download Excel", file_name="search_results.xlsx")
 
65
 
66
+ # Link button click to action
67
+ search_button.click(fn=gradio_interface, inputs=search_input, outputs=[search_output, download_button])
68
 
69
  # Launch the Gradio app
70
  app.launch()