Shawn Shen commited on
Commit
3aa4b4a
1 Parent(s): 7714123

consistent result

Browse files
Files changed (3) hide show
  1. Predictor.py +0 -276
  2. app.py +167 -15
  3. requirements.txt +3 -3
Predictor.py DELETED
@@ -1,276 +0,0 @@
1
- # CUDA_VISIBLE_DEVICES=2 python -m torch.distributed.launch --nproc_per_node=1 --master_port 3303 Predictor.py --predict_file /home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta --outdir /home/ubuntu/Experimental_Data/try --outfilename try_RVACv1
2
-
3
-
4
- import os
5
- from Bio import SeqIO
6
- import sys
7
-
8
- # import argparse
9
- # from argparse import Namespace
10
- # import pathlib
11
-
12
- import torch
13
- import torch.nn as nn
14
- import torch.nn.functional as F
15
-
16
- # import esm
17
- # from esm.data import *
18
- # from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
19
- from esm.model.esm2 import ESM2 as ESM2_SISS
20
- # from esm.model.esm2_supervised import ESM2
21
- from esm import Alphabet, FastaBatchedDataset#, ProteinBertModel, pretrained, MSATransformer
22
-
23
-
24
- import numpy as np
25
- import pandas as pd
26
- import random
27
- # import math
28
- # import scipy.stats as stats
29
- # from scipy.stats import spearmanr, pearsonr
30
- # from sklearn import preprocessing
31
- # from copy import deepcopy
32
- from tqdm import tqdm#, trange
33
- # import matplotlib.pyplot as plt
34
- # import seaborn as sns
35
- # from sklearn.model_selection import KFold
36
- # from torch.optim.lr_scheduler import StepLR
37
- # import torch.distributed as dist
38
- # from torch.nn.parallel import DistributedDataParallel
39
- # from torch.utils.data.distributed import DistributedSampler
40
- from io import StringIO
41
-
42
- seed = 1337
43
- random.seed(seed)
44
- np.random.seed(seed)
45
- torch.manual_seed(seed)
46
- # torch.cuda.manual_seed(seed)
47
- # torch.cuda.manual_seed_all(seed)
48
-
49
- # parser = argparse.ArgumentParser()
50
- # parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
51
- # parser.add_argument('--local-rank', type=int, default=-1, help="DDP parameter, do not modify")
52
-
53
- # parser.add_argument('--outdir', type=str, default = '/home/ubuntu/Experimental_Data/try')
54
- # parser.add_argument('--outfilename', type=str, default = 'try_RVACv1')
55
- # parser.add_argument('--predict_file', type = str, default = '/home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta')
56
- # args = parser.parse_args()
57
- # print(args)
58
-
59
- global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device_ids, device
60
- modelfile = 'model.pkl'
61
-
62
- # model_info = modelfile.split('/')[-1].split('_')
63
- # for item in model_info:
64
- # if 'layers' in item:
65
- # layers = int(item[0])
66
- # elif 'heads' in item:
67
- # heads = int(item[:-5])
68
- # elif 'embedsize' in item:
69
- # embed_dim = int(item[:-9])
70
- # elif 'batchToks' in item:
71
- # batch_toks = 4096
72
-
73
- layers = 6
74
- heads = 16
75
- embed_dim = 128
76
- batch_toks = 4096
77
-
78
- inp_len = 50
79
-
80
- # device_ids = list(map(int, args.device_ids.split(',')))
81
- # dist.init_process_group(backend='nccl')
82
- # device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))
83
- device = "cpu"
84
- # torch.cuda.set_device(device)
85
-
86
- # local_rank = args.local_rank
87
- local_rank = -1
88
- # storage_id = int(device_ids[local_rank])
89
- storage_id = 0
90
-
91
- # repr_layers = [layers]
92
- include = ["mean"]
93
-
94
- class CNN_linear(nn.Module):
95
- def __init__(self,
96
- border_mode='same', filter_len=8, nbr_filters=120,
97
- dropout1=0, dropout2=0):
98
-
99
- super(CNN_linear, self).__init__()
100
-
101
- self.embedding_size = embed_dim
102
- self.border_mode = border_mode
103
- self.inp_len = inp_len
104
- self.nodes = 40
105
- self.cnn_layers = 0
106
- self.filter_len = filter_len
107
- self.nbr_filters = nbr_filters
108
- self.dropout1 = dropout1
109
- self.dropout2 = dropout2
110
- self.dropout3 = 0.5
111
-
112
- self.esm2 = ESM2_SISS(num_layers = layers,
113
- embed_dim = embed_dim,
114
- attention_heads = heads,
115
- alphabet = alphabet)
116
-
117
- self.conv1 = nn.Conv1d(in_channels = self.embedding_size,
118
- out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
119
- self.conv2 = nn.Conv1d(in_channels = self.nbr_filters,
120
- out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
121
-
122
- self.dropout1 = nn.Dropout(self.dropout1)
123
- self.dropout2 = nn.Dropout(self.dropout2)
124
- self.dropout3 = nn.Dropout(self.dropout3)
125
- self.relu = nn.ReLU()
126
- self.flatten = nn.Flatten()
127
- self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
128
- self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
129
- self.output = nn.Linear(in_features = self.nodes, out_features = 1)
130
- self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
131
- self.magic_output = nn.Linear(in_features = 1, out_features = 1)
132
-
133
- def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
134
-
135
- # x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
136
- x = self.esm2(tokens, [layers])
137
-
138
- x = x["representations"][layers][:, 0]
139
- x_o = x.unsqueeze(2)
140
-
141
- x = self.flatten(x_o)
142
- o_linear = self.fc(x)
143
- o_relu = self.relu(o_linear)
144
- o_dropout = self.dropout3(o_relu)
145
- o = self.output(o_dropout)
146
- return o
147
-
148
- def eval_step(dataloader, model, threshold = 0.5):
149
- model.eval()
150
- y_pred_list, y_prob_list = [], []
151
- ids_list, strs_list = [], []
152
- with torch.no_grad():
153
- # for (ids, strs, _, toks, _, _) in tqdm(dataloader):
154
- for ids, strs, toks in tqdm(dataloader):
155
- ids_list.extend(ids)
156
- strs_list.extend(strs)
157
- # toks = toks.to(device)
158
-
159
- # print(toks)
160
- logits = model(toks)
161
-
162
- logits = logits.reshape(-1)
163
- y_prob = torch.sigmoid(logits)
164
- y_pred = (y_prob > threshold).long()
165
-
166
-
167
- y_prob_list.extend(y_prob.cpu().detach().tolist())
168
- y_pred_list.extend(y_pred.cpu().detach().tolist())
169
-
170
- data_pred = pd.DataFrame([ids_list, strs_list, y_prob_list, y_pred_list], index = ['ID', 'Sequence', "Probability as 5'UTR", "Prediction as 5'UTR"]).T
171
- return data_pred
172
-
173
-
174
-
175
- def generate_dataset_dataloader(ids, seqs):
176
- # dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
177
- dataset = FastaBatchedDataset(ids, seqs)
178
- batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
179
- dataloader = torch.utils.data.DataLoader(dataset,
180
- collate_fn=alphabet.get_batch_converter(),
181
- batch_sampler=batches,
182
- shuffle = False)
183
- print(f"{len(dataset)} sequences")
184
- return dataset, dataloader
185
-
186
- def read_fasta(file):
187
- # 判断文件是否为空
188
- if os.path.getsize(file) == 0:
189
- print("Error: The file is empty!")
190
- sys.exit()
191
-
192
- ids = []
193
- sequences = []
194
-
195
- for record in SeqIO.parse(file, "fasta"):
196
- # 检查序列的开头是否为">"
197
- # if not record.id.startswith('>'):
198
- # print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
199
- # continue
200
-
201
- # 检查序列是否只包含A, G, C, T
202
- sequence = str(record.seq).upper()[-inp_len:]
203
- if not set(sequence).issubset(set("AGCT")):
204
- print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
205
- continue
206
-
207
- # 将符合条件的序列添加到列表中
208
- ids.append(record.id)
209
- sequences.append(sequence)
210
-
211
- return ids, sequences
212
-
213
- def read_raw(raw_input):
214
- ids = []
215
- sequences = []
216
-
217
- file = StringIO(raw_input)
218
- for record in SeqIO.parse(file, "fasta"):
219
- # 检查序列的开头是否为">"
220
- # if not record.id.startswith('>'):
221
- # print(f"Error: The sequence '{record.id}' is not properly formatted, it does not start with '>'. Skipping...")
222
- # continue
223
-
224
- # 检查序列是否只包含A, G, C, T
225
- sequence = str(record.seq).upper()[-inp_len:]
226
- if not set(sequence).issubset(set("AGCT")):
227
- print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
228
- continue
229
-
230
- # 将符合条件的序列添加到列表中
231
- ids.append(record.id)
232
- sequences.append(sequence)
233
-
234
- return ids, sequences
235
-
236
- #######
237
-
238
- # alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
239
- alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
240
- # print(alphabet.tok_to_idx)
241
- # assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
242
- alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
243
-
244
- def predict_file(input_file):
245
- print('====Load Data====')
246
- ids, seqs = read_fasta(input_file)
247
- _, dataloader = generate_dataset_dataloader(ids, seqs)
248
-
249
- model = CNN_linear().to(device)
250
- # model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
251
- model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
252
- # model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
253
-
254
- print('====Predict====')
255
- pred = eval_step(dataloader, model)
256
- return pred
257
- # print(pred)
258
- # print('====Save Results====')
259
- # if not os.path.exists(args.outdir): os.makedirs(args.outdir)
260
- # pred.to_csv(f'{args.outdir}/{args.outfilename}_prediction_results.csv', index = False)
261
-
262
- def predict_raw(raw_input):
263
- print('====Parse Input====')
264
- ids, seqs = read_raw(raw_input)
265
- _, dataloader = generate_dataset_dataloader(ids, seqs)
266
-
267
- model = CNN_linear().to(device)
268
- # model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=lambda storage, loc : storage.cuda(storage_id)).items()}, strict = False)
269
- model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
270
- # model = DistributedDataParallel(model, device_ids=[device_ids[local_rank]], output_device=device_ids[local_rank], find_unused_parameters=True)
271
-
272
- print('====Predict====')
273
- pred = eval_step(dataloader, model)
274
-
275
- # print(pred)
276
- return pred
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,26 +1,178 @@
1
  import streamlit as st
2
- # from Bio import SeqIO
3
- from Predictor import predict_file, predict_raw
 
 
 
 
 
 
4
  from io import StringIO
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  st.title("5' UTR prediction")
7
 
8
  st.subheader("Input sequence")
9
- #x = st.slider('Select a value')
10
- # seq = ""
11
- seq = st.text_area("Input your sequence here", value="")
12
  st.subheader("Upload sequence file")
13
  uploaded = st.file_uploader("Sequence file in FASTA format")
14
- # if uploaded:
15
- # st.write(StringIO(uploaded.getvalue().decode("utf-8")))
16
- # seq = uploaded.read()
17
- # print(seq)
18
- # predict_file(uploaded)
19
- # seq = SeqIO.read(uploaded, "fasta").seq
20
- # st.subheader("Prediction result:")
21
  if st.button("Predict"):
22
- st.write("Prediction result:")
23
  if uploaded:
24
- st.write(predict_raw(uploaded.getvalue().decode()))
 
 
 
25
  else:
26
- st.write(predict_raw(seq))
 
 
 
 
 
 
1
  import streamlit as st
2
+ from Bio import SeqIO
3
+ import torch
4
+ import torch.nn as nn
5
+ from esm.model.esm2 import ESM2 as ESM2_SISS
6
+ from esm import Alphabet, FastaBatchedDataset
7
+ import pandas as pd
8
+ from tqdm import tqdm
9
+
10
  from io import StringIO
11
 
12
+ seed = 1337
13
+ torch.manual_seed(seed)
14
+
15
+ global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device
16
+ modelfile = 'model.pkl'
17
+
18
+ layers = 6
19
+ heads = 16
20
+ embed_dim = 128
21
+ batch_toks = 4096
22
+
23
+ inp_len = 50
24
+
25
+ device = "cpu"
26
+
27
+ alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
28
+ alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
29
+
30
+ class CNN_linear(nn.Module):
31
+ def __init__(self,
32
+ border_mode='same', filter_len=8, nbr_filters=120,
33
+ dropout1=0, dropout2=0):
34
+
35
+ super(CNN_linear, self).__init__()
36
+
37
+ self.embedding_size = embed_dim
38
+ self.border_mode = border_mode
39
+ self.inp_len = inp_len
40
+ self.nodes = 40
41
+ self.cnn_layers = 0
42
+ self.filter_len = filter_len
43
+ self.nbr_filters = nbr_filters
44
+ self.dropout1 = dropout1
45
+ self.dropout2 = dropout2
46
+ self.dropout3 = 0.5
47
+
48
+ self.esm2 = ESM2_SISS(num_layers = layers,
49
+ embed_dim = embed_dim,
50
+ attention_heads = heads,
51
+ alphabet = alphabet)
52
+
53
+ self.conv1 = nn.Conv1d(in_channels = self.embedding_size,
54
+ out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
55
+ self.conv2 = nn.Conv1d(in_channels = self.nbr_filters,
56
+ out_channels = self.nbr_filters, kernel_size = self.filter_len, padding = self.border_mode)
57
+
58
+ self.dropout1 = nn.Dropout(self.dropout1)
59
+ self.dropout2 = nn.Dropout(self.dropout2)
60
+ self.dropout3 = nn.Dropout(self.dropout3)
61
+ self.relu = nn.ReLU()
62
+ self.flatten = nn.Flatten()
63
+ self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
64
+ self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
65
+ self.output = nn.Linear(in_features = self.nodes, out_features = 1)
66
+ self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
67
+ self.magic_output = nn.Linear(in_features = 1, out_features = 1)
68
+
69
+ def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
70
+
71
+ # x = self.esm2(tokens, [layers], need_head_weights, return_contacts, return_representation)
72
+ x = self.esm2(tokens, [layers])
73
+
74
+ x = x["representations"][layers][:, 0]
75
+ x_o = x.unsqueeze(2)
76
+
77
+ x = self.flatten(x_o)
78
+ o_linear = self.fc(x)
79
+ o_relu = self.relu(o_linear)
80
+ o_dropout = self.dropout3(o_relu)
81
+ o = self.output(o_dropout)
82
+ return o
83
+
84
+ def eval_step(dataloader, model, threshold = 0.5):
85
+ model.eval()
86
+ logits_list= []
87
+ # y_pred_list, y_prob_list = [], []
88
+ ids_list, strs_list = [], []
89
+ my_bar = st.progress(0, text="Running UTR_LM")
90
+ with torch.no_grad():
91
+ # for (ids, strs, _, toks, _, _) in tqdm(dataloader):
92
+ for i, (ids, strs, toks) in enumerate(dataloader):
93
+ ids_list.extend(ids)
94
+ strs_list.extend(strs)
95
+ # toks = toks.to(device)
96
+ my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
97
+ # print(toks)
98
+ logits = model(toks)
99
+
100
+ logits = logits.reshape(-1)
101
+ # y_prob = torch.sigmoid(logits)
102
+ # y_pred = (y_prob > threshold).long()
103
+
104
+ logits_list.extend(logits.tolist())
105
+ # y_prob_list.extend(y_prob.tolist())
106
+ # y_pred_list.extend(y_pred.tolist())
107
+
108
+ st.success('Done', icon="✅")
109
+ data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list})
110
+ return data_pred
111
+
112
+ def read_raw(raw_input):
113
+ ids = []
114
+ sequences = []
115
+
116
+ file = StringIO(raw_input)
117
+ for record in SeqIO.parse(file, "fasta"):
118
+
119
+ # 检查序列是否只包含A, G, C, T
120
+ sequence = str(record.seq.back_transcribe()).upper()[-inp_len:]
121
+ if not set(sequence).issubset(set("AGCT")):
122
+ st.write(f"Record '{record.description}' was skipped for containing invalid characters. Only A, G, C, T(U) are allowed.")
123
+ continue
124
+
125
+ # 将符合条件的序列添加到列表中
126
+ ids.append(record.id)
127
+ sequences.append(sequence)
128
+
129
+ return ids, sequences
130
+
131
+ def generate_dataset_dataloader(ids, seqs):
132
+ # dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
133
+ dataset = FastaBatchedDataset(ids, seqs)
134
+ batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
135
+ dataloader = torch.utils.data.DataLoader(dataset,
136
+ collate_fn=alphabet.get_batch_converter(),
137
+ batch_sampler=batches,
138
+ shuffle = False)
139
+ # dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=batches, shuffle = False)
140
+ st.write(f"{len(dataset)} sequences")
141
+ return dataset, dataloader
142
+
143
+ def predict_raw(raw_input):
144
+ # st.write('====Parse Input====')
145
+ ids, seqs = read_raw(raw_input)
146
+ _, dataloader = generate_dataset_dataloader(ids, seqs)
147
+
148
+ model = CNN_linear()
149
+
150
+ model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
151
+
152
+ # st.write('====Predict====')
153
+ pred = eval_step(dataloader, model)
154
+
155
+ # print(pred)
156
+ return pred
157
+
158
  st.title("5' UTR prediction")
159
 
160
  st.subheader("Input sequence")
161
+
162
+ seq = st.text_area("FASTA format only", value="")
 
163
  st.subheader("Upload sequence file")
164
  uploaded = st.file_uploader("Sequence file in FASTA format")
165
+
 
 
 
 
 
 
166
  if st.button("Predict"):
 
167
  if uploaded:
168
+ result = predict_raw(uploaded.getvalue().decode())
169
+ result_file = result.to_csv(index=False)
170
+ st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
171
+ st.dataframe(result)
172
  else:
173
+ result = predict_raw(seq)
174
+ result_file = result.to_csv(index=False)
175
+ st.download_button("Download", result_file, file_name="UTR_LM_prediction.csv")
176
+ st.dataframe(result)
177
+
178
+
requirements.txt CHANGED
@@ -1,6 +1,6 @@
1
- biopython
2
- torch
3
  numpy
4
  pandas
5
  tqdm
6
- fair-esm
 
1
+ biopython==1.81
2
+ torch==2.0.1
3
  numpy
4
  pandas
5
  tqdm
6
+ fair-esm @ git+https://github.com/facebookresearch/esm.git@900251ba3e2b7cdc06b44b10dfa3a0c1dd49752b