Spaces:
Running
Running
import streamlit as st | |
from Bio import SeqIO | |
import torch | |
import torch.nn as nn | |
from esm.model.esm2 import ESM2 as ESM2_SISS | |
from esm import Alphabet, FastaBatchedDataset | |
import pandas as pd | |
from tqdm import tqdm | |
from io import StringIO | |
seed = 1337 | |
torch.manual_seed(seed) | |
global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device | |
modelfile = 'model.pkl' | |
layers = 6 | |
heads = 16 | |
embed_dim = 128 | |
batch_toks = 4096 | |
inp_len = 50 | |
device = "cpu" | |
alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>")) | |
alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9} | |
class CNN_linear(nn.Module): | |
def __init__(self, | |
border_mode='same', filter_len=8, nbr_filters=120, | |
dropout1=0, dropout2=0): | |
super(CNN_linear, self).__init__() | |
self.embedding_size = embed_dim | |
self.border_mode = border_mode | |
self.inp_len = inp_len | |
self.nodes = 40 | |
self.cnn_layers = 0 | |
self.filter_len = filter_len | |
self.nbr_filters = nbr_filters | |
self.dropout1 = dropout1 | |
self.dropout2 = dropout2 | |
self.dropout3 = 0.5 | |
self.esm2 = ESM2_SISS(num_layers = layers, | |
embed_dim = embed_dim, | |
attention_heads = heads, | |
alphabet = alphabet) | |
self.conv1 = nn.Conv1d(in_channels = self.embedding_size, | |
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode) | |
self.conv2 = nn.Conv1d(in_channels = self.nbr_filters, | |
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode) | |
self.dropout1 = nn.Dropout(self.dropout1) | |
self.dropout2 = nn.Dropout(self.dropout2) | |
self.dropout3 = nn.Dropout(self.dropout3) | |
self.relu = nn.ReLU() | |
self.flatten = nn.Flatten() | |
self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes) | |
self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes) | |
self.output = nn.Linear(in_features = self.nodes, out_features = 1) | |
self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1) | |
self.magic_output = nn.Linear(in_features = 1, out_features = 1) | |
def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True): | |
# x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation) | |
x = self.esm2(tokens, [layers]) | |
x = x["representations"][layers][:, 0] | |
x_o = x.unsqueeze(2) | |
x = self.flatten(x_o) | |
o_linear = self.fc(x) | |
o_relu = self.relu(o_linear) | |
o_dropout = self.dropout3(o_relu) | |
o = self.output(o_dropout) | |
return o | |
def eval_step(dataloader, model, threshold = 0.5): | |
model.eval() | |
logits_list= [] | |
# y_pred_list, y_prob_list = [], [] | |
ids_list, strs_list = [], [] | |
my_bar = st.progress(0, text="Running UTR_LM") | |
with torch.no_grad(): | |
# for (ids, strs, _, toks, _, _) in tqdm(dataloader): | |
for i, (ids, strs, toks) in enumerate(dataloader): | |
ids_list.extend(ids) | |
strs_list.extend(strs) | |
# toks = toks.to(device) | |
my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM") | |
# print(toks) | |
logits = model(toks) | |
logits = logits.reshape(-1) | |
# y_prob = torch.sigmoid(logits) | |
# y_pred = (y_prob > threshold).long() | |
logits_list.extend(logits.tolist()) | |
# y_prob_list.extend(y_prob.tolist()) | |
# y_pred_list.extend(y_pred.tolist()) | |
st.success('Done', icon="✅") | |
data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list}) | |
return data_pred | |
def read_raw(raw_input): | |
ids = [] | |
sequences = [] | |
file = StringIO(raw_input) | |
for record in SeqIO.parse(file, "fasta"): | |
# 检查序列是否只包含A, G, C, T | |
sequence = str(record.seq.back_transcribe()).upper()[-inp_len:] | |
if not set(sequence).issubset(set("AGCT")): | |
st.write(f"Record '{record.description}' was skipped for containing invalid characters. Only A, G, C, T(U) are allowed.") | |
continue | |
# 将符合条件的序列添加到列表中 | |
ids.append(record.id) | |
sequences.append(sequence) | |
return ids, sequences | |
def generate_dataset_dataloader(ids, seqs): | |
# dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0) | |
dataset = FastaBatchedDataset(ids, seqs) | |
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2) | |
dataloader = torch.utils.data.DataLoader(dataset, | |
collate_fn=alphabet.get_batch_converter(), | |
batch_sampler=batches, | |
shuffle = False) | |
# dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batches, shuffle = False) | |
st.write(f"{len(dataset)} sequences") | |
return dataset, dataloader | |
def predict_raw(raw_input): | |
# st.write('====Parse Input====') | |
ids, seqs = read_raw(raw_input) | |
_, dataloader = generate_dataset_dataloader(ids, seqs) | |
model = CNN_linear() | |
model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False) | |
# st.write('====Predict====') | |
pred = eval_step(dataloader, model) | |
# print(pred) | |
return pred | |
st.title("5' UTR prediction") | |
st.subheader("Input sequence") | |
seq = st.text_area("FASTA format only", value="") | |
st.subheader("Upload sequence file") | |
uploaded = st.file_uploader("Sequence file in FASTA format") | |
if st.button("Predict"): | |
if uploaded: | |
result = predict_raw(uploaded.getvalue().decode()) | |
result_file = result.to_csv(index=False) | |
st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv") | |
st.dataframe(result) | |
else: | |
result = predict_raw(seq) | |
result_file = result.to_csv(index=False) | |
st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv") | |
st.dataframe(result) | |