Spaces:
Running
Running
Shawn Shen
commited on
Commit
•
3aa4b4a
1
Parent(s):
7714123
consistent result
Browse files- Predictor.py +0 -276
- app.py +167 -15
- requirements.txt +3 -3
Predictor.py
DELETED
@@ -1,276 +0,0 @@
|
|
1 |
-
# CUDA_VISIBLE_DEVICES=2 python -m torch.distributed.launch --nproc_per_node=1 --master_port 3303 Predictor.py --predict_file /home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta --outdir /home/ubuntu/Experimental_Data/try --outfilename try_RVACv1
|
2 |
-
|
3 |
-
|
4 |
-
import os
|
5 |
-
from Bio import SeqIO
|
6 |
-
import sys
|
7 |
-
|
8 |
-
# import argparse
|
9 |
-
# from argparse import Namespace
|
10 |
-
# import pathlib
|
11 |
-
|
12 |
-
import torch
|
13 |
-
import torch.nn as nn
|
14 |
-
import torch.nn.functional as F
|
15 |
-
|
16 |
-
# import esm
|
17 |
-
# from esm.data import *
|
18 |
-
# from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
|
19 |
-
from esm.model.esm2 import ESM2 as ESM2_SISS
|
20 |
-
# from esm.model.esm2_supervised import ESM2
|
21 |
-
from esm import Alphabet, FastaBatchedDataset#, ProteinBertModel, pretrained, MSATransformer
|
22 |
-
|
23 |
-
|
24 |
-
import numpy as np
|
25 |
-
import pandas as pd
|
26 |
-
import random
|
27 |
-
# import math
|
28 |
-
# import scipy.stats as stats
|
29 |
-
# from scipy.stats import spearmanr, pearsonr
|
30 |
-
# from sklearn import preprocessing
|
31 |
-
# from copy import deepcopy
|
32 |
-
from tqdm import tqdm#, trange
|
33 |
-
# import matplotlib.pyplot as plt
|
34 |
-
# import seaborn as sns
|
35 |
-
# from sklearn.model_selection import KFold
|
36 |
-
# from torch.optim.lr_scheduler import StepLR
|
37 |
-
# import torch.distributed as dist
|
38 |
-
# from torch.nn.parallel import DistributedDataParallel
|
39 |
-
# from torch.utils.data.distributed import DistributedSampler
|
40 |
-
from io import StringIO
|
41 |
-
|
42 |
-
seed = 1337
|
43 |
-
random.seed(seed)
|
44 |
-
np.random.seed(seed)
|
45 |
-
torch.manual_seed(seed)
|
46 |
-
# torch.cuda.manual_seed(seed)
|
47 |
-
# torch.cuda.manual_seed_all(seed)
|
48 |
-
|
49 |
-
# parser = argparse.ArgumentParser()
|
50 |
-
# parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
|
51 |
-
# parser.add_argument('--local-rank', type=int, default=-1, help="DDP parameter, do not modify")
|
52 |
-
|
53 |
-
# parser.add_argument('--outdir', type=str, default = '/home/ubuntu/Experimental_Data/try')
|
54 |
-
# parser.add_argument('--outfilename', type=str, default = 'try_RVACv1')
|
55 |
-
# parser.add_argument('--predict_file', type = str, default = '/home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta')
|
56 |
-
# args = parser.parse_args()
|
57 |
-
# print(args)
|
58 |
-
|
59 |
-
global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device_ids, device
|
60 |
-
modelfile = 'model.pkl'
|
61 |
-
|
62 |
-
# model_info = modelfile.split('/')[-1].split('_')
|
63 |
-
# for item in model_info:
|
64 |
-
# if 'layers' in item:
|
65 |
-
# layers = int(item[0])
|
66 |
-
# elif 'heads' in item:
|
67 |
-
# heads = int(item[:-5])
|
68 |
-
# elif 'embedsize' in item:
|
69 |
-
# embed_dim = int(item[:-9])
|
70 |
-
# elif 'batchToks' in item:
|
71 |
-
# batch_toks = 4096
|
72 |
-
|
73 |
-
layers = 6
|
74 |
-
heads = 16
|
75 |
-
embed_dim = 128
|
76 |
-
batch_toks = 4096
|
77 |
-
|
78 |
-
inp_len = 50
|
79 |
-
|
80 |
-
# device_ids = list(map(int, args.device_ids.split(',')))
|
81 |
-
# dist.init_process_group(backend='nccl')
|
82 |
-
# device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))
|
83 |
-
device = "cpu"
|
84 |
-
# torch.cuda.set_device(device)
|
85 |
-
|
86 |
-
# local_rank = args.local_rank
|
87 |
-
local_rank = -1
|
88 |
-
# storage_id = int(device_ids[local_rank])
|
89 |
-
storage_id = 0
|
90 |
-
|
91 |
-
# repr_layers = [layers]
|
92 |
-
include = ["mean"]
|
93 |
-
|
94 |
-
class CNN_linear(nn.Module):
|
95 |
-
def __init__(self,
|
96 |
-
border_mode='same', filter_len=8, nbr_filters=120,
|
97 |
-
dropout1=0, dropout2=0):
|
98 |
-
|
99 |
-
super(CNN_linear, self).__init__()
|
100 |
-
|
101 |
-
self.embedding_size = embed_dim
|
102 |
-
self.border_mode = border_mode
|
103 |
-
self.inp_len = inp_len
|
104 |
-
self.nodes = 40
|
105 |
-
self.cnn_layers = 0
|
106 |
-
self.filter_len = filter_len
|
107 |
-
self.nbr_filters = nbr_filters
|
108 |
-
self.dropout1 = dropout1
|
109 |
-
self.dropout2 = dropout2
|
110 |
-
self.dropout3 = 0.5
|
111 |
-
|
112 |
-
self.esm2 = ESM2_SISS(num_layers = layers,
|
113 |
-
embed_dim = embed_dim,
|
114 |
-
attention_heads = heads,
|
115 |
-
alphabet = alphabet)
|
116 |
-
|
117 |
-
self.conv1 = nn.Conv1d(in_channels = self.embedding_size,
|
118 |
-
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
119 |
-
self.conv2 = nn.Conv1d(in_channels = self.nbr_filters,
|
120 |
-
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
121 |
-
|
122 |
-
self.dropout1 = nn.Dropout(self.dropout1)
|
123 |
-
self.dropout2 = nn.Dropout(self.dropout2)
|
124 |
-
self.dropout3 = nn.Dropout(self.dropout3)
|
125 |
-
self.relu = nn.ReLU()
|
126 |
-
self.flatten = nn.Flatten()
|
127 |
-
self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
|
128 |
-
self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
|
129 |
-
self.output = nn.Linear(in_features = self.nodes, out_features = 1)
|
130 |
-
self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
|
131 |
-
self.magic_output = nn.Linear(in_features = 1, out_features = 1)
|
132 |
-
|
133 |
-
def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
|
134 |
-
|
135 |
-
# x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
|
136 |
-
x = self.esm2(tokens, [layers])
|
137 |
-
|
138 |
-
x = x["representations"][layers][:, 0]
|
139 |
-
x_o = x.unsqueeze(2)
|
140 |
-
|
141 |
-
x = self.flatten(x_o)
|
142 |
-
o_linear = self.fc(x)
|
143 |
-
o_relu = self.relu(o_linear)
|
144 |
-
o_dropout = self.dropout3(o_relu)
|
145 |
-
o = self.output(o_dropout)
|
146 |
-
return o
|
147 |
-
|
148 |
-
def eval_step(dataloader, model, threshold = 0.5):
|
149 |
-
model.eval()
|
150 |
-
y_pred_list, y_prob_list = [], []
|
151 |
-
ids_list, strs_list = [], []
|
152 |
-
with torch.no_grad():
|
153 |
-
# for (ids, strs, _, toks, _, _) in tqdm(dataloader):
|
154 |
-
for ids, strs, toks in tqdm(dataloader):
|
155 |
-
ids_list.extend(ids)
|
156 |
-
strs_list.extend(strs)
|
157 |
-
# toks = toks.to(device)
|
158 |
-
|
159 |
-
# print(toks)
|
160 |
-
logits = model(toks)
|
161 |
-
|
162 |
-
logits = logits.reshape(-1)
|
163 |
-
y_prob = torch.sigmoid(logits)
|
164 |
-
y_pred = (y_prob > threshold).long()
|
165 |
-
|
166 |
-
|
167 |
-
y_prob_list.extend(y_prob.cpu().detach().tolist())
|
168 |
-
y_pred_list.extend(y_pred.cpu().detach().tolist())
|
169 |
-
|
170 |
-
data_pred = pd.DataFrame([ids_list, strs_list, y_prob_list, y_pred_list], index = ['ID', 'Sequence', "Probability as 5'UTR", "Prediction as 5'UTR"]).T
|
171 |
-
return data_pred
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
def generate_dataset_dataloader(ids, seqs):
|
176 |
-
# dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
|
177 |
-
dataset = FastaBatchedDataset(ids, seqs)
|
178 |
-
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
|
179 |
-
dataloader = torch.utils.data.DataLoader(dataset,
|
180 |
-
collate_fn=alphabet.get_batch_converter(),
|
181 |
-
batch_sampler=batches,
|
182 |
-
shuffle = False)
|
183 |
-
print(f"{len(dataset)} sequences")
|
184 |
-
return dataset, dataloader
|
185 |
-
|
186 |
-
def read_fasta(file):
|
187 |
-
# 判断文件是否为空
|
188 |
-
if os.path.getsize(file) == 0:
|
189 |
-
print("Error: The file is empty!")
|
190 |
-
sys.exit()
|
191 |
-
|
192 |
-
ids = []
|
193 |
-
sequences = []
|
194 |
-
|
195 |
-
for record in SeqIO.parse(file, "fasta"):
|
196 |
-
# 检查序列的开头是否为">"
|
197 |
-
# if not record.id.startswith('>'):
|
198 |
-
# print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
|
199 |
-
# continue
|
200 |
-
|
201 |
-
# 检查序列是否只包含A, G, C, T
|
202 |
-
sequence = str(record.seq).upper()[-inp_len:]
|
203 |
-
if not set(sequence).issubset(set("AGCT")):
|
204 |
-
print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
|
205 |
-
continue
|
206 |
-
|
207 |
-
# 将符合条件的序列添加到列表中
|
208 |
-
ids.append(record.id)
|
209 |
-
sequences.append(sequence)
|
210 |
-
|
211 |
-
return ids, sequences
|
212 |
-
|
213 |
-
def read_raw(raw_input):
|
214 |
-
ids = []
|
215 |
-
sequences = []
|
216 |
-
|
217 |
-
file = StringIO(raw_input)
|
218 |
-
for record in SeqIO.parse(file, "fasta"):
|
219 |
-
# 检查序列的开头是否为">"
|
220 |
-
# if not record.id.startswith('>'):
|
221 |
-
# print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
|
222 |
-
# continue
|
223 |
-
|
224 |
-
# 检查序列是否只包含A, G, C, T
|
225 |
-
sequence = str(record.seq).upper()[-inp_len:]
|
226 |
-
if not set(sequence).issubset(set("AGCT")):
|
227 |
-
print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
|
228 |
-
continue
|
229 |
-
|
230 |
-
# 将符合条件的序列添加到列表中
|
231 |
-
ids.append(record.id)
|
232 |
-
sequences.append(sequence)
|
233 |
-
|
234 |
-
return ids, sequences
|
235 |
-
|
236 |
-
#######
|
237 |
-
|
238 |
-
# alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
|
239 |
-
alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
|
240 |
-
# print(alphabet.tok_to_idx)
|
241 |
-
# assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
242 |
-
alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
243 |
-
|
244 |
-
def predict_file(input_file):
|
245 |
-
print('====Load Data====')
|
246 |
-
ids, seqs = read_fasta(input_file)
|
247 |
-
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
248 |
-
|
249 |
-
model = CNN_linear().to(device)
|
250 |
-
# model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
|
251 |
-
model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
|
252 |
-
# model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
|
253 |
-
|
254 |
-
print('====Predict====')
|
255 |
-
pred = eval_step(dataloader, model)
|
256 |
-
return pred
|
257 |
-
# print(pred)
|
258 |
-
# print('====Save Results====')
|
259 |
-
# if not os.path.exists(args.outdir): os.makedirs(args.outdir)
|
260 |
-
# pred.to_csv(f'{args.outdir}/{args.outfilename}_prediction_results.csv', index = False)
|
261 |
-
|
262 |
-
def predict_raw(raw_input):
|
263 |
-
print('====Parse Input====')
|
264 |
-
ids, seqs = read_raw(raw_input)
|
265 |
-
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
266 |
-
|
267 |
-
model = CNN_linear().to(device)
|
268 |
-
# model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
|
269 |
-
model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
|
270 |
-
# model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
|
271 |
-
|
272 |
-
print('====Predict====')
|
273 |
-
pred = eval_step(dataloader, model)
|
274 |
-
|
275 |
-
# print(pred)
|
276 |
-
return pred
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -1,26 +1,178 @@
|
|
1 |
import streamlit as st
|
2 |
-
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
from io import StringIO
|
5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
st.title("5' UTR prediction")
|
7 |
|
8 |
st.subheader("Input sequence")
|
9 |
-
|
10 |
-
|
11 |
-
seq = st.text_area("Input your sequence here", value="")
|
12 |
st.subheader("Upload sequence file")
|
13 |
uploaded = st.file_uploader("Sequence file in FASTA format")
|
14 |
-
|
15 |
-
# st.write(StringIO(uploaded.getvalue().decode("utf-8")))
|
16 |
-
# seq = uploaded.read()
|
17 |
-
# print(seq)
|
18 |
-
# predict_file(uploaded)
|
19 |
-
# seq = SeqIO.read(uploaded, "fasta").seq
|
20 |
-
# st.subheader("Prediction result:")
|
21 |
if st.button("Predict"):
|
22 |
-
st.write("Prediction result:")
|
23 |
if uploaded:
|
24 |
-
|
|
|
|
|
|
|
25 |
else:
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
+
from Bio import SeqIO
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from esm.model.esm2 import ESM2 as ESM2_SISS
|
6 |
+
from esm import Alphabet, FastaBatchedDataset
|
7 |
+
import pandas as pd
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
from io import StringIO
|
11 |
|
12 |
+
seed = 1337
|
13 |
+
torch.manual_seed(seed)
|
14 |
+
|
15 |
+
global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device
|
16 |
+
modelfile = 'model.pkl'
|
17 |
+
|
18 |
+
layers = 6
|
19 |
+
heads = 16
|
20 |
+
embed_dim = 128
|
21 |
+
batch_toks = 4096
|
22 |
+
|
23 |
+
inp_len = 50
|
24 |
+
|
25 |
+
device = "cpu"
|
26 |
+
|
27 |
+
alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
|
28 |
+
alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
29 |
+
|
30 |
+
class CNN_linear(nn.Module):
|
31 |
+
def __init__(self,
|
32 |
+
border_mode='same', filter_len=8, nbr_filters=120,
|
33 |
+
dropout1=0, dropout2=0):
|
34 |
+
|
35 |
+
super(CNN_linear, self).__init__()
|
36 |
+
|
37 |
+
self.embedding_size = embed_dim
|
38 |
+
self.border_mode = border_mode
|
39 |
+
self.inp_len = inp_len
|
40 |
+
self.nodes = 40
|
41 |
+
self.cnn_layers = 0
|
42 |
+
self.filter_len = filter_len
|
43 |
+
self.nbr_filters = nbr_filters
|
44 |
+
self.dropout1 = dropout1
|
45 |
+
self.dropout2 = dropout2
|
46 |
+
self.dropout3 = 0.5
|
47 |
+
|
48 |
+
self.esm2 = ESM2_SISS(num_layers = layers,
|
49 |
+
embed_dim = embed_dim,
|
50 |
+
attention_heads = heads,
|
51 |
+
alphabet = alphabet)
|
52 |
+
|
53 |
+
self.conv1 = nn.Conv1d(in_channels = self.embedding_size,
|
54 |
+
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
55 |
+
self.conv2 = nn.Conv1d(in_channels = self.nbr_filters,
|
56 |
+
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
57 |
+
|
58 |
+
self.dropout1 = nn.Dropout(self.dropout1)
|
59 |
+
self.dropout2 = nn.Dropout(self.dropout2)
|
60 |
+
self.dropout3 = nn.Dropout(self.dropout3)
|
61 |
+
self.relu = nn.ReLU()
|
62 |
+
self.flatten = nn.Flatten()
|
63 |
+
self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
|
64 |
+
self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
|
65 |
+
self.output = nn.Linear(in_features = self.nodes, out_features = 1)
|
66 |
+
self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
|
67 |
+
self.magic_output = nn.Linear(in_features = 1, out_features = 1)
|
68 |
+
|
69 |
+
def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
|
70 |
+
|
71 |
+
# x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
|
72 |
+
x = self.esm2(tokens, [layers])
|
73 |
+
|
74 |
+
x = x["representations"][layers][:, 0]
|
75 |
+
x_o = x.unsqueeze(2)
|
76 |
+
|
77 |
+
x = self.flatten(x_o)
|
78 |
+
o_linear = self.fc(x)
|
79 |
+
o_relu = self.relu(o_linear)
|
80 |
+
o_dropout = self.dropout3(o_relu)
|
81 |
+
o = self.output(o_dropout)
|
82 |
+
return o
|
83 |
+
|
84 |
+
def eval_step(dataloader, model, threshold = 0.5):
|
85 |
+
model.eval()
|
86 |
+
logits_list= []
|
87 |
+
# y_pred_list, y_prob_list = [], []
|
88 |
+
ids_list, strs_list = [], []
|
89 |
+
my_bar = st.progress(0, text="Running UTR_LM")
|
90 |
+
with torch.no_grad():
|
91 |
+
# for (ids, strs, _, toks, _, _) in tqdm(dataloader):
|
92 |
+
for i, (ids, strs, toks) in enumerate(dataloader):
|
93 |
+
ids_list.extend(ids)
|
94 |
+
strs_list.extend(strs)
|
95 |
+
# toks = toks.to(device)
|
96 |
+
my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
|
97 |
+
# print(toks)
|
98 |
+
logits = model(toks)
|
99 |
+
|
100 |
+
logits = logits.reshape(-1)
|
101 |
+
# y_prob = torch.sigmoid(logits)
|
102 |
+
# y_pred = (y_prob > threshold).long()
|
103 |
+
|
104 |
+
logits_list.extend(logits.tolist())
|
105 |
+
# y_prob_list.extend(y_prob.tolist())
|
106 |
+
# y_pred_list.extend(y_pred.tolist())
|
107 |
+
|
108 |
+
st.success('Done', icon="✅")
|
109 |
+
data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list})
|
110 |
+
return data_pred
|
111 |
+
|
112 |
+
def read_raw(raw_input):
|
113 |
+
ids = []
|
114 |
+
sequences = []
|
115 |
+
|
116 |
+
file = StringIO(raw_input)
|
117 |
+
for record in SeqIO.parse(file, "fasta"):
|
118 |
+
|
119 |
+
# 检查序列是否只包含A, G, C, T
|
120 |
+
sequence = str(record.seq.back_transcribe()).upper()[-inp_len:]
|
121 |
+
if not set(sequence).issubset(set("AGCT")):
|
122 |
+
st.write(f"Record '{record.description}' was skipped for containing invalid characters. Only A, G, C, T(U) are allowed.")
|
123 |
+
continue
|
124 |
+
|
125 |
+
# 将符合条件的序列添加到列表中
|
126 |
+
ids.append(record.id)
|
127 |
+
sequences.append(sequence)
|
128 |
+
|
129 |
+
return ids, sequences
|
130 |
+
|
131 |
+
def generate_dataset_dataloader(ids, seqs):
|
132 |
+
# dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
|
133 |
+
dataset = FastaBatchedDataset(ids, seqs)
|
134 |
+
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
|
135 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
136 |
+
collate_fn=alphabet.get_batch_converter(),
|
137 |
+
batch_sampler=batches,
|
138 |
+
shuffle = False)
|
139 |
+
# dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batches, shuffle = False)
|
140 |
+
st.write(f"{len(dataset)} sequences")
|
141 |
+
return dataset, dataloader
|
142 |
+
|
143 |
+
def predict_raw(raw_input):
|
144 |
+
# st.write('====Parse Input====')
|
145 |
+
ids, seqs = read_raw(raw_input)
|
146 |
+
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
147 |
+
|
148 |
+
model = CNN_linear()
|
149 |
+
|
150 |
+
model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
|
151 |
+
|
152 |
+
# st.write('====Predict====')
|
153 |
+
pred = eval_step(dataloader, model)
|
154 |
+
|
155 |
+
# print(pred)
|
156 |
+
return pred
|
157 |
+
|
158 |
st.title("5' UTR prediction")
|
159 |
|
160 |
st.subheader("Input sequence")
|
161 |
+
|
162 |
+
seq = st.text_area("FASTA format only", value="")
|
|
|
163 |
st.subheader("Upload sequence file")
|
164 |
uploaded = st.file_uploader("Sequence file in FASTA format")
|
165 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
if st.button("Predict"):
|
|
|
167 |
if uploaded:
|
168 |
+
result = predict_raw(uploaded.getvalue().decode())
|
169 |
+
result_file = result.to_csv(index=False)
|
170 |
+
st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
|
171 |
+
st.dataframe(result)
|
172 |
else:
|
173 |
+
result = predict_raw(seq)
|
174 |
+
result_file = result.to_csv(index=False)
|
175 |
+
st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
|
176 |
+
st.dataframe(result)
|
177 |
+
|
178 |
+
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
-
biopython
|
2 |
-
torch
|
3 |
numpy
|
4 |
pandas
|
5 |
tqdm
|
6 |
-
fair-esm
|
|
|
1 |
+
biopython==1.81
|
2 |
+
torch==2.0.1
|
3 |
numpy
|
4 |
pandas
|
5 |
tqdm
|
6 |
+
fair-esm @ git+https://github.com/facebookresearch/esm.git@900251ba3e2b7cdc06b44b10dfa3a0c1dd49752b
|