import streamlit as st
import pandas as pd, numpy as np
from html import escape
import os
import torch
from transformers import RobertaModel, AutoTokenizer
# allow_output_mutation=True stops st.cache from hashing the returned objects
# (model, tokenizer, DataFrame, array) to detect mutation. The previous
# hash_funcs dict referenced `text_encoder`/`tokenizer` as keys before those
# names existed, which raised NameError when the decorator was evaluated.
@st.cache(show_spinner=False, allow_output_mutation=True)
def load():
    """Load the text encoder, tokenizer, image metadata and image embeddings.

    Wrapped in ``st.cache`` so the (slow) model loading and file reads are not
    repeated on every Streamlit rerun.

    Returns:
        tuple: ``(text_encoder, tokenizer, df, image_embeddings)`` where
        ``text_encoder`` is the ``RobertaModel`` and ``tokenizer`` the
        ``AutoTokenizer`` for ``'SajjadAyoubi/clip-fa-text'``, ``df`` is the
        DataFrame read from ``data.csv`` and ``image_embeddings`` is the array
        read from ``embeddings.npy``.
    """
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    df = pd.read_csv('data.csv')
    image_embeddings = np.load('embeddings.npy')
    return text_encoder, tokenizer, df, image_embeddings


text_encoder, tokenizer, df, image_embeddings = load()
def get_html(url_list, height=224):
    """Render ``(url, link)`` pairs as a flex-wrapped HTML image gallery.

    Args:
        url_list: iterable of ``(url, link)`` pairs. ``url`` becomes the image
            ``src``. When ``link`` is a non-empty string, the image is wrapped in
            an anchor to it that opens in a new tab.
        height: image height in CSS pixels.

    Returns:
        str: one ``<div>`` containing all images, meant for
        ``st.markdown(..., unsafe_allow_html=True)``.
    """
    items = []
    for url, link in url_list:
        # Values come from the data file and are placed inside single-quoted
        # attributes, so escape them to keep quotes or '<' from breaking the markup.
        item = f"<img style='height: {height}px; margin: 5px' src='{escape(str(url))}'>"
        # The links come from pd.read_csv, where a missing cell is float NaN,
        # not "", so check the type before testing for emptiness.
        if isinstance(link, str) and link:
            item = f"<a href='{escape(link)}' target='_blank'>" + item + "</a>"
        items.append(item)
    return ("<div style='margin-top: 20px; max-width: 1200px; display: flex; "
            "flex-wrap: wrap; justify-content: space-evenly'>"
            + "".join(items) + "</div>")
# Legacy st.cache also hashes the global objects a cached function reads.
# Skip hashing the model and tokenizer so their weights and vocab are not walked
# on every call. Both are produced once by the cached load() above.
@st.cache(show_spinner=False,
          hash_funcs={type(text_encoder): lambda _: None,
                      type(tokenizer): lambda _: None})
def image_search(query, top_k=8):
    """Return the ``top_k`` images whose embeddings are most similar to ``query``.

    The query is encoded with ``text_encoder`` (its ``pooler_output``) and
    compared to every row of ``image_embeddings`` by cosine similarity.

    Args:
        query: search text.
        top_k: maximum number of results to return.

    Returns:
        list of ``(path, link)`` tuples taken from ``df``, best match first.
        Assumes row ``i`` of ``df`` describes row ``i`` of ``image_embeddings``
        (TODO confirm against how data.csv / embeddings.npy were produced).
    """
    with torch.no_grad():
        # truncation=True keeps very long queries within the model's max length.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # image_embeddings may be a NumPy array (np.load). torch.cosine_similarity
        # needs a tensor, and its dtype must match the text embedding.
        candidates = torch.as_tensor(image_embeddings, dtype=text_embedding.dtype)
        # Presumably text_embedding is (1, dim) and candidates is (num_images, dim),
        # giving one score per image — TODO confirm embeddings.npy layout.
        scores = torch.cosine_similarity(text_embedding, candidates)
    # .tolist() yields plain ints, which is what DataFrame.iloc expects.
    best = scores.argsort(descending=True)[:top_k].tolist()
    return [(df.iloc[i]['path'], df.iloc[i]['link']) for i in best]
# Markdown text rendered in the sidebar by main().
description = '''
# Semantic image search :)
'''
def main():
    """Build the page: sidebar blurb, a centred query box, and the result gallery.

    Nothing is searched or rendered below the input while the query is empty.
    """
    # NOTE(review): this injects an effectively empty HTML payload; any styling it
    # was meant to carry is not present here — confirm the intended content.
    st.markdown('''
''', unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # Narrow outer columns act as padding around the wider middle one.
    _, centre, _ = st.columns((1, 3, 1))
    query = centre.text_input('', value='clouds at sunset')
    if not query:
        return

    st.markdown(get_html(image_search(query)), unsafe_allow_html=True)


if __name__ == '__main__':
    main()