|
import os
from html import escape

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import RobertaModel, AutoTokenizer
|
|
|
|
|
# allow_output_mutation=True: skip st.cache's per-call hashing of the returned
# model/tokenizer/tensors, which is expensive for large objects.
@st.cache(show_spinner=False, allow_output_mutation=True)
def load():
    """Load the text encoder, tokenizer, image URLs and precomputed image embeddings.

    Returns:
        Tuple ``(text_encoder, tokenizer, links, image_embeddings)``: the
        ``SajjadAyoubi/clip-fa-text`` RoBERTa model and its tokenizer, the
        array stored in ``link.npy``, and the object stored in ``embeddings.pt``.
    """
    model_name = 'SajjadAyoubi/clip-fa-text'
    text_encoder = RobertaModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Both allow_pickle=True and torch.load unpickle file contents, which can
    # execute arbitrary code: only point these at trusted, locally produced files.
    links = np.load('link.npy', allow_pickle=True)
    # The text encoder stays on CPU, so load the embeddings onto CPU as well;
    # this also prevents a failure when the file was saved from a GPU and the
    # app runs on a CPU-only host.
    image_embeddings = torch.load('embeddings.pt', map_location='cpu')
    return text_encoder, tokenizer, links, image_embeddings
|
|
|
|
|
# Fetch the search resources once per script run. load() is wrapped in
# st.cache, so Streamlit reruns of this script reuse the cached objects.
(
    text_encoder,
    tokenizer,
    links,
    image_embeddings,
) = load()
|
|
|
|
|
def get_html(url_list, height=224):
    """Build an HTML gallery that shows the given image URLs in a wrapping flex row.

    Args:
        url_list: Iterable of image URLs. Each one is HTML-escaped before it is
            placed in the ``src`` attribute.
        height: Height in pixels applied to every ``<img>`` element.

    Returns:
        A single ``<div>`` string containing one ``<img>`` per URL, in input order.
    """
    container_open = (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    )
    # escape() also escapes quote characters, so a URL cannot break out of the
    # single-quoted src attribute.
    images = "".join(
        f"<img style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        for url in url_list
    )
    return container_open + images + "</div>"
|
|
|
|
|
@st.cache(show_spinner=False)
def image_search(query, top_k=8):
    """Return the image URLs whose embeddings are most similar to *query*.

    The query is encoded with the module-level ``tokenizer`` and ``text_encoder``
    and compared against ``image_embeddings`` by cosine similarity.

    Args:
        query: Search text.
        top_k: Maximum number of results. It is clamped to the number of scored
            images; values <= 0 yield an empty list.

    Returns:
        A list of at most ``top_k`` entries from ``links``, best match first.
    """
    # NOTE(review): legacy st.cache also hashes the globals this function reads
    # (model, tokenizer, embeddings) on each call; consider hash_funcs if slow.
    with torch.no_grad():
        # truncation=True keeps very long queries within the model's max input length.
        inputs = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**inputs).pooler_output
        # assumes image_embeddings is (num_images, dim) so the single query row
        # broadcasts against every image — TODO confirm against embeddings.pt
        scores = torch.cosine_similarity(text_embedding, image_embeddings)
        k = max(0, min(top_k, scores.shape[-1]))
        # topk avoids fully sorting every score when only the best k are needed.
        best = torch.topk(scores, k).indices
    return [links[i] for i in best.tolist()]
|
|
|
|
|
# Markdown shown in the sidebar by main() via st.sidebar.markdown.
description = '''

# Semantic image search :)

'''
|
|
|
|
|
# Page-level CSS injected by main(): caps the content width, lays radio options
# out in a centered row, adjusts top padding, and hides the main menu and footer.
_PAGE_CSS = '''
<style>
.block-container{
    max-width: 1200px;
}
div.row-widget.stRadio > div{
    flex-direction:row;
    display: flex;
    justify-content: center;
}
div.row-widget.stRadio > div > label{
    margin-left: 5px;
    margin-right: 5px;
}
section.main>div:first-child {
    padding-top: 0px;
}
section:not(.main)>div:first-child {
    padding-top: 30px;
}
div.reportview-container > section:first-child{
    max-width: 320px;
}
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
</style>'''


def main():
    """Render the search page.

    Injects the page CSS, shows the sidebar description, reads a query from a
    centered text box and, for a non-blank query, displays the matching images.
    """
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # A 1:3:1 column split centers the search box.
    _, center, _ = st.columns((1, 3, 1))
    query = center.text_input('', value='clouds at sunset').strip()

    # Skip searching on empty or whitespace-only input.
    if query:
        results = image_search(query)
        st.markdown(get_html(results), unsafe_allow_html=True)
|
|
|
|
|
# Script entry point: render the app when this file is executed directly
# (presumably via `streamlit run`).
if __name__ == '__main__':
    main()