Spaces:
Runtime error
Runtime error
Shawn37
commited on
Commit
•
a6b0878
1
Parent(s):
5f3a2c7
finished frame
Browse files- .gitattributes +1 -0
- Predictor.py +275 -0
- app.py +22 -0
- model.pkl +3 -0
- requirements.txt +5 -0
.gitattributes
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
model.pkl filter=lfs diff=lfs merge=lfs -text
|
Predictor.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# CUDA_VISIBLE_DEVICES=2 python -m torch.distributed.launch --nproc_per_node=1 --master_port 3303 Predictor.py --predict_file /home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta --outdir /home/ubuntu/Experimental_Data/try --outfilename try_RVACv1
|
2 |
+
|
3 |
+
|
4 |
+
import os
|
5 |
+
from Bio import SeqIO
|
6 |
+
import sys
|
7 |
+
|
8 |
+
# import argparse
|
9 |
+
# from argparse import Namespace
|
10 |
+
# import pathlib
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
import torch.nn.functional as F
|
15 |
+
|
16 |
+
# import esm
|
17 |
+
# from esm.data import *
|
18 |
+
# from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
|
19 |
+
from esm.model.esm2 import ESM2 as ESM2_SISS
|
20 |
+
# from esm.model.esm2_supervised import ESM2
|
21 |
+
from esm import Alphabet, FastaBatchedDataset#, ProteinBertModel, pretrained, MSATransformer
|
22 |
+
|
23 |
+
|
24 |
+
import numpy as np
|
25 |
+
import pandas as pd
|
26 |
+
import random
|
27 |
+
# import math
|
28 |
+
# import scipy.stats as stats
|
29 |
+
# from scipy.stats import spearmanr, pearsonr
|
30 |
+
# from sklearn import preprocessing
|
31 |
+
# from copy import deepcopy
|
32 |
+
from tqdm import tqdm#, trange
|
33 |
+
# import matplotlib.pyplot as plt
|
34 |
+
# import seaborn as sns
|
35 |
+
# from sklearn.model_selection import KFold
|
36 |
+
# from torch.optim.lr_scheduler import StepLR
|
37 |
+
# import torch.distributed as dist
|
38 |
+
# from torch.nn.parallel import DistributedDataParallel
|
39 |
+
# from torch.utils.data.distributed import DistributedSampler
|
40 |
+
from io import StringIO
|
41 |
+
|
42 |
+
seed = 1337
|
43 |
+
random.seed(seed)
|
44 |
+
np.random.seed(seed)
|
45 |
+
torch.manual_seed(seed)
|
46 |
+
# torch.cuda.manual_seed(seed)
|
47 |
+
# torch.cuda.manual_seed_all(seed)
|
48 |
+
|
49 |
+
# parser = argparse.ArgumentParser()
|
50 |
+
# parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
|
51 |
+
# parser.add_argument('--local-rank', type=int, default=-1, help="DDP parameter, do not modify")
|
52 |
+
|
53 |
+
# parser.add_argument('--outdir', type=str, default = '/home/ubuntu/Experimental_Data/try')
|
54 |
+
# parser.add_argument('--outfilename', type=str, default = 'try_RVACv1')
|
55 |
+
# parser.add_argument('--predict_file', type = str, default = '/home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta')
|
56 |
+
# args = parser.parse_args()
|
57 |
+
# print(args)
|
58 |
+
|
59 |
+
global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device_ids, device
|
60 |
+
modelfile = 'model.pkl'
|
61 |
+
|
62 |
+
# model_info = modelfile.split('/')[-1].split('_')
|
63 |
+
# for item in model_info:
|
64 |
+
# if 'layers' in item:
|
65 |
+
# layers = int(item[0])
|
66 |
+
# elif 'heads' in item:
|
67 |
+
# heads = int(item[:-5])
|
68 |
+
# elif 'embedsize' in item:
|
69 |
+
# embed_dim = int(item[:-9])
|
70 |
+
# elif 'batchToks' in item:
|
71 |
+
# batch_toks = 4096
|
72 |
+
|
73 |
+
layers = 6
|
74 |
+
heads = 16
|
75 |
+
embed_dim = 128
|
76 |
+
batch_toks = 4096
|
77 |
+
|
78 |
+
inp_len = 50
|
79 |
+
|
80 |
+
# device_ids = list(map(int, args.device_ids.split(',')))
|
81 |
+
# dist.init_process_group(backend='nccl')
|
82 |
+
# device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))
|
83 |
+
device = "cpu"
|
84 |
+
# torch.cuda.set_device(device)
|
85 |
+
|
86 |
+
# local_rank = args.local_rank
|
87 |
+
local_rank = -1
|
88 |
+
# storage_id = int(device_ids[local_rank])
|
89 |
+
storage_id = 0
|
90 |
+
|
91 |
+
# repr_layers = [layers]
|
92 |
+
include = ["mean"]
|
93 |
+
|
94 |
+
class CNN_linear(nn.Module):
|
95 |
+
def __init__(self,
|
96 |
+
border_mode='same', filter_len=8, nbr_filters=120,
|
97 |
+
dropout1=0, dropout2=0):
|
98 |
+
|
99 |
+
super(CNN_linear, self).__init__()
|
100 |
+
|
101 |
+
self.embedding_size = embed_dim
|
102 |
+
self.border_mode = border_mode
|
103 |
+
self.inp_len = inp_len
|
104 |
+
self.nodes = 40
|
105 |
+
self.cnn_layers = 0
|
106 |
+
self.filter_len = filter_len
|
107 |
+
self.nbr_filters = nbr_filters
|
108 |
+
self.dropout1 = dropout1
|
109 |
+
self.dropout2 = dropout2
|
110 |
+
self.dropout3 = 0.5
|
111 |
+
|
112 |
+
self.esm2 = ESM2_SISS(num_layers = layers,
|
113 |
+
embed_dim = embed_dim,
|
114 |
+
attention_heads = heads,
|
115 |
+
alphabet = alphabet)
|
116 |
+
|
117 |
+
self.conv1 = nn.Conv1d(in_channels = self.embedding_size,
|
118 |
+
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
119 |
+
self.conv2 = nn.Conv1d(in_channels = self.nbr_filters,
|
120 |
+
out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
|
121 |
+
|
122 |
+
self.dropout1 = nn.Dropout(self.dropout1)
|
123 |
+
self.dropout2 = nn.Dropout(self.dropout2)
|
124 |
+
self.dropout3 = nn.Dropout(self.dropout3)
|
125 |
+
self.relu = nn.ReLU()
|
126 |
+
self.flatten = nn.Flatten()
|
127 |
+
self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
|
128 |
+
self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
|
129 |
+
self.output = nn.Linear(in_features = self.nodes, out_features = 1)
|
130 |
+
self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
|
131 |
+
self.magic_output = nn.Linear(in_features = 1, out_features = 1)
|
132 |
+
|
133 |
+
def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
|
134 |
+
|
135 |
+
# x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
|
136 |
+
x = self.esm2(tokens, [layers])
|
137 |
+
|
138 |
+
x = x["representations"][layers][:, 0]
|
139 |
+
x_o = x.unsqueeze(2)
|
140 |
+
|
141 |
+
x = self.flatten(x_o)
|
142 |
+
o_linear = self.fc(x)
|
143 |
+
o_relu = self.relu(o_linear)
|
144 |
+
o_dropout = self.dropout3(o_relu)
|
145 |
+
o = self.output(o_dropout)
|
146 |
+
return o
|
147 |
+
|
148 |
+
def eval_step(dataloader, model, threshold = 0.5):
|
149 |
+
model.eval()
|
150 |
+
y_pred_list, y_prob_list = [], []
|
151 |
+
ids_list, strs_list = [], []
|
152 |
+
with torch.no_grad():
|
153 |
+
# for (ids, strs, _, toks, _, _) in tqdm(dataloader):
|
154 |
+
for ids, strs, toks in tqdm(dataloader):
|
155 |
+
ids_list.extend(ids)
|
156 |
+
strs_list.extend(strs)
|
157 |
+
# toks = toks.to(device)
|
158 |
+
|
159 |
+
# print(toks)
|
160 |
+
logits = model(toks)
|
161 |
+
|
162 |
+
logits = logits.reshape(-1)
|
163 |
+
y_prob = torch.sigmoid(logits)
|
164 |
+
y_pred = (y_prob > threshold).long()
|
165 |
+
|
166 |
+
|
167 |
+
y_prob_list.extend(y_prob.cpu().detach().tolist())
|
168 |
+
y_pred_list.extend(y_pred.cpu().detach().tolist())
|
169 |
+
|
170 |
+
data_pred = pd.DataFrame([ids_list, strs_list, y_prob_list, y_pred_list], index = ['ID', 'Sequence', "Probability as 5'UTR", "Prediction as 5'UTR"]).T
|
171 |
+
return data_pred
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
def generate_dataset_dataloader(ids, seqs):
|
176 |
+
# dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
|
177 |
+
dataset = FastaBatchedDataset(ids, seqs)
|
178 |
+
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
|
179 |
+
dataloader = torch.utils.data.DataLoader(dataset,
|
180 |
+
collate_fn=alphabet.get_batch_converter(),
|
181 |
+
batch_sampler=batches,
|
182 |
+
shuffle = False)
|
183 |
+
print(f"{len(dataset)} sequences")
|
184 |
+
return dataset, dataloader
|
185 |
+
|
186 |
+
def read_fasta(file):
|
187 |
+
# 判断文件是否为空
|
188 |
+
if os.path.getsize(file) == 0:
|
189 |
+
print("Error: The file is empty!")
|
190 |
+
sys.exit()
|
191 |
+
|
192 |
+
ids = []
|
193 |
+
sequences = []
|
194 |
+
|
195 |
+
for record in SeqIO.parse(file, "fasta"):
|
196 |
+
# 检查序列的开头是否为">"
|
197 |
+
# if not record.id.startswith('>'):
|
198 |
+
# print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
|
199 |
+
# continue
|
200 |
+
|
201 |
+
# 检查序列是否只包含A, G, C, T
|
202 |
+
sequence = str(record.seq).upper()[-inp_len:]
|
203 |
+
if not set(sequence).issubset(set("AGCT")):
|
204 |
+
print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
|
205 |
+
continue
|
206 |
+
|
207 |
+
# 将符合条件的序列添加到列表中
|
208 |
+
ids.append(record.id)
|
209 |
+
sequences.append(sequence)
|
210 |
+
|
211 |
+
return ids, sequences
|
212 |
+
|
213 |
+
def read_raw(raw_input):
|
214 |
+
ids = []
|
215 |
+
sequences = []
|
216 |
+
|
217 |
+
file = StringIO(raw_input)
|
218 |
+
for record in SeqIO.parse(file, "fasta"):
|
219 |
+
# 检查序列的开头是否为">"
|
220 |
+
# if not record.id.startswith('>'):
|
221 |
+
# print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
|
222 |
+
# continue
|
223 |
+
|
224 |
+
# 检查序列是否只包含A, G, C, T
|
225 |
+
sequence = str(record.seq).upper()[-inp_len:]
|
226 |
+
if not set(sequence).issubset(set("AGCT")):
|
227 |
+
print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
|
228 |
+
continue
|
229 |
+
|
230 |
+
# 将符合条件的序列添加到列表中
|
231 |
+
ids.append(record.id)
|
232 |
+
sequences.append(sequence)
|
233 |
+
|
234 |
+
return ids, sequences
|
235 |
+
|
236 |
+
#######
|
237 |
+
|
238 |
+
# alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
|
239 |
+
alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
|
240 |
+
# print(alphabet.tok_to_idx)
|
241 |
+
# assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
242 |
+
alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
243 |
+
|
244 |
+
def predict_file(input_file):
|
245 |
+
print('====Load Data====')
|
246 |
+
ids, seqs = read_fasta(input_file)
|
247 |
+
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
248 |
+
|
249 |
+
model = CNN_linear().to(device)
|
250 |
+
# model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
|
251 |
+
model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
|
252 |
+
# model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
|
253 |
+
|
254 |
+
print('====Predict====')
|
255 |
+
pred = eval_step(dataloader, model)
|
256 |
+
|
257 |
+
print(pred)
|
258 |
+
# print('====Save Results====')
|
259 |
+
# if not os.path.exists(args.outdir): os.makedirs(args.outdir)
|
260 |
+
# pred.to_csv(f'{args.outdir}/{args.outfilename}_prediction_results.csv', index = False)
|
261 |
+
|
262 |
+
def predict_raw(raw_input):
|
263 |
+
print('====Parse Input====')
|
264 |
+
ids, seqs = read_raw(raw_input)
|
265 |
+
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
266 |
+
|
267 |
+
model = CNN_linear().to(device)
|
268 |
+
# model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
|
269 |
+
model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
|
270 |
+
# model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
|
271 |
+
|
272 |
+
print('====Predict====')
|
273 |
+
pred = eval_step(dataloader, model)
|
274 |
+
|
275 |
+
print(pred)
|
app.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from Bio import SeqIO
|
3 |
+
from Predictor import predict_file, predict_raw
|
4 |
+
|
5 |
+
st.title("5' UTR prediction")
|
6 |
+
|
7 |
+
st.subheader("Input sequence")
|
8 |
+
#x = st.slider('Select a value')
|
9 |
+
# seq = ""
|
10 |
+
seq = st.text_input("Input your sequence here", value="")
|
11 |
+
st.subheader("Upload sequence file")
|
12 |
+
uploaded = st.file_uploader("Sequence file in FASTA format")
|
13 |
+
# if uploaded:
|
14 |
+
# predict_file(uploaded)
|
15 |
+
# seq = SeqIO.read(uploaded, "fasta").seq
|
16 |
+
st.subheader("Prediction result:")
|
17 |
+
if st.button("Predict"):
|
18 |
+
if uploaded:
|
19 |
+
predict_file(uploaded)
|
20 |
+
else:
|
21 |
+
predict_raw(seq)
|
22 |
+
# st.write("Sequence length = ", len(seq))
|
model.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:705ea278849702e12285d4059dc15d902cc445f458729415ed83b1bb6515f3d3
|
3 |
+
size 4905915
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
biopython
|
2 |
+
pytorch
|
3 |
+
numpy
|
4 |
+
pandas
|
5 |
+
tqdm
|