# CLIPfa-Demo / app.py
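"""Streamlit demo for CLIPfa semantic image search.

A text query is embedded with the CLIPfa text encoder
('SajjadAyoubi/clip-fa-text') and matched against precomputed image embeddings
by cosine similarity. Assumes `data.csv` (with `path` and `link` columns) and
`embeddings.npy` sit next to this file.
"""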
import numpy as np
import pandas as pd
import streamlit as st
import torch
from html import escape
from transformers import AutoTokenizer, RobertaModel


# Load the text encoder, tokenizer, image metadata and precomputed image
# embeddings once per session. allow_output_mutation=True tells Streamlit to
# skip hashing the returned model and tokenizer objects.
@st.cache(show_spinner=False, allow_output_mutation=True)
def load():
    text_encoder = RobertaModel.from_pretrained('SajjadAyoubi/clip-fa-text')
    tokenizer = AutoTokenizer.from_pretrained('SajjadAyoubi/clip-fa-text')
    df = pd.read_csv('data.csv')
    # Convert the precomputed image embeddings to a tensor so they can be
    # compared with the text embedding via torch.cosine_similarity.
    image_embeddings = torch.from_numpy(np.load('embeddings.npy'))
    return text_encoder, tokenizer, df, image_embeddings


text_encoder, tokenizer, df, image_embeddings = load()


def get_html(url_list, height=224):
    # Lay the result images out in a responsive flexbox grid; images with a
    # non-empty link are wrapped in an anchor that opens in a new tab.
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url, link in url_list:
        html2 = f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        if len(link) > 0:
            html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
        html = html + html2
    html += "</div>"
    return html
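
# For example, get_html([('https://example.com/cat.jpg', 'https://example.com')])
# (hypothetical URLs) yields a <div> containing a single clickable <img> thumbnail.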


# Cache results per query; hash_funcs skips hashing the unhashable globals
# (model, tokenizer, dataframe, embedding tensor) used inside the function.
@st.cache(show_spinner=False,
          hash_funcs={type(text_encoder): lambda _: None, type(tokenizer): lambda _: None,
                      pd.DataFrame: lambda _: None, torch.Tensor: lambda _: None})
def image_search(query, top_k=8):
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')).pooler_output
        values, indices = torch.cosine_similarity(
            text_embedding, image_embeddings).sort(descending=True)
    return [(df.iloc[i]['path'], df.iloc[i]['link']) for i in indices[:top_k].tolist()]
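
# For example, image_search('clouds at sunset', top_k=8) returns up to eight
# (image path, page link) pairs from data.csv, best matches first.
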
description = '''
# Semantic image search :)
'''


def main():
    # Page-level CSS tweaks: widen the layout, center row widgets and hide
    # Streamlit's default menu and footer.
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # A single centered text box drives the search; matches are rendered as an
    # HTML grid of linked thumbnails.
    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input('', value='clouds at sunset')
    if len(query) > 0:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)


if __name__ == '__main__':
    main()
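
# Launch locally with `streamlit run app.py` (requires streamlit, torch,
# transformers, pandas and numpy, plus data.csv and embeddings.npy alongside).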