SajjadAyoubi committed on
Commit
4f12085
1 Parent(s): fd91484

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd, numpy as np
3
+ from html import escape
4
+ import os
5
+ import torch
6
+ from transformers import RobertaModel, AutoTokenizer
7
+
8
+
9
# NOTE: the previous decorator passed ``hash_funcs`` keyed on ``text_encoder``
# and ``tokenizer``; those names are locals of ``load`` and do not exist when
# the decorator is evaluated, so importing the module raised NameError.
# ``allow_output_mutation=True`` tells Streamlit not to hash the cached return
# values at all, which is the documented way to cache model objects.
@st.cache(show_spinner=False, allow_output_mutation=True)
def load():
    """Load the text encoder, tokenizer, image metadata and image embeddings.

    The result is cached by Streamlit so the model and data files are only
    loaded once per server process rather than on every script rerun.

    Returns:
        tuple: ``(text_encoder, tokenizer, df, image_embeddings)`` where
        ``text_encoder`` is the ``RobertaModel`` loaded from
        ``SajjadAyoubi/clip-fa-text``, ``tokenizer`` is its matching
        tokenizer, ``df`` is the DataFrame read from ``data.csv`` and
        ``image_embeddings`` is the array read from ``embeddings.npy``.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    df = pd.read_csv('data.csv')
    image_embeddings = np.load('embeddings.npy')
    return text_encoder, tokenizer, df, image_embeddings


# Module-level handles used by ``image_search``.
text_encoder, tokenizer, df, image_embeddings = load()
22
+
23
+
24
def get_html(url_list, height=224):
    """Render image results as a wrapping flexbox gallery.

    Args:
        url_list: Iterable of ``(url, link)`` pairs. ``url`` is the image
            source. ``link`` is an optional target URL; when it is a
            non-empty string the image is wrapped in an anchor that opens in
            a new tab. Any other value (empty string, ``None``, or the NaN
            that pandas produces for an empty CSV cell) means "no link".
        height: Height of every image in pixels.

    Returns:
        str: An HTML fragment consisting of one ``<div>`` containing the
        images. Both ``url`` and ``link`` are HTML-escaped.
    """
    pieces = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, link in url_list:
        tag = f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        # Guard on type first: len()/escape() on a float NaN or None would raise.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        pieces.append(tag)
    pieces.append("</div>")
    # Single join instead of repeated string concatenation inside the loop.
    return "".join(pieces)
35
+
36
+
37
# NOTE(review): the original had a bare ``st.cache(show_spinner=False)``
# statement here (no ``@``), which built a decorator and discarded it. It is
# removed rather than applied: legacy ``st.cache`` hashes the globals a
# function references, and the model/tokenizer used below are not reliably
# hashable. Re-enable caching deliberately (with suitable ``hash_funcs``) if
# it is wanted.
def image_search(query, top_k=8):
    """Return the images whose embeddings best match a text query.

    Args:
        query: Search text; it is tokenized and encoded with the module-level
            ``tokenizer`` and ``text_encoder``.
        top_k: Maximum number of results to return.

    Returns:
        list[tuple]: ``(path, link)`` pairs taken from the ``path`` and
        ``link`` columns of the module-level ``df``, ordered from most to
        least similar.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        encoded = tokenizer(query, return_tensors='pt')
        text_embedding = text_encoder(**encoded).pooler_output
        # ``image_embeddings`` is loaded with ``np.load``; torch ops need a
        # tensor, and matching the dtype avoids float32/float64 mismatches.
        image_matrix = torch.as_tensor(image_embeddings,
                                       dtype=text_embedding.dtype)
        # presumably text_embedding is (1, dim) and image_matrix is
        # (n_images, dim), giving one score per image - TODO confirm
        scores = torch.cosine_similarity(text_embedding, image_matrix)
        _, order = scores.sort(descending=True)
    # Plain Python ints are valid ``iloc`` positions; 0-d tensors are not.
    best = order[:top_k].tolist()
    return [(df.iloc[i]['path'], df.iloc[i]['link']) for i in best]
44
+
45
+
46
# Markdown shown in the sidebar (a single top-level heading).
description = "\n# Semantic image search :)\n"
49
+
50
+
51
def main():
    """Render the page: inject CSS, show the sidebar text, run the search.

    A text box is placed in the middle of three columns; whenever it holds a
    non-empty query, the matching images are rendered below it as raw HTML.
    """
    # Layout tweaks and hiding of Streamlit's default menu and footer.
    page_css = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Only the middle (widest) of the three columns is used.
    centre = st.columns((1, 3, 1))[1]
    query = centre.text_input('', value='clouds at sunset')
    if query:
        gallery = get_html(image_search(query))
        st.markdown(gallery, unsafe_allow_html=True)


if __name__ == '__main__':
    main()