Shawn37 committed on
Commit
a6b0878
1 Parent(s): 5f3a2c7

finished frame

Browse files
Files changed (5) hide show
  1. .gitattributes +1 -0
  2. Predictor.py +275 -0
  3. app.py +22 -0
  4. model.pkl +3 -0
  5. requirements.txt +5 -0
.gitattributes ADDED
@@ -0,0 +1 @@
 
 
1
+ model.pkl filter=lfs diff=lfs merge=lfs -text
Predictor.py ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # CUDA_VISIBLE_DEVICES=2 python -m torch.distributed.launch --nproc_per_node=1 --master_port 3303 Predictor.py --predict_file /home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta --outdir /home/ubuntu/Experimental_Data/try --outfilename try_RVACv1
2
+
3
+
4
+ import os
5
+ from Bio import SeqIO
6
+ import sys
7
+
8
+ # import argparse
9
+ # from argparse import Namespace
10
+ # import pathlib
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+
16
+ # import esm
17
+ # from esm.data import *
18
+ # from esm.model.esm2_secondarystructure import ESM2 as ESM2_SISS
19
+ from esm.model.esm2 import ESM2 as ESM2_SISS
20
+ # from esm.model.esm2_supervised import ESM2
21
+ from esm import Alphabet, FastaBatchedDataset#, ProteinBertModel, pretrained, MSATransformer
22
+
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+ import random
27
+ # import math
28
+ # import scipy.stats as stats
29
+ # from scipy.stats import spearmanr, pearsonr
30
+ # from sklearn import preprocessing
31
+ # from copy import deepcopy
32
+ from tqdm import tqdm#, trange
33
+ # import matplotlib.pyplot as plt
34
+ # import seaborn as sns
35
+ # from sklearn.model_selection import KFold
36
+ # from torch.optim.lr_scheduler import StepLR
37
+ # import torch.distributed as dist
38
+ # from torch.nn.parallel import DistributedDataParallel
39
+ # from torch.utils.data.distributed import DistributedSampler
40
+ from io import StringIO
41
+
42
# Seed Python's, NumPy's and PyTorch's random number generators with one
# fixed value.
seed = 1337
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# torch.cuda.manual_seed(seed)
# torch.cuda.manual_seed_all(seed)

# Former command-line interface, disabled; the values below are hard-coded.
# parser = argparse.ArgumentParser()
# parser.add_argument('--device_ids', type=str, default='0', help="Training Devices")
# parser.add_argument('--local-rank', type=int, default=-1, help="DDP parameter, do not modify")

# parser.add_argument('--outdir', type=str, default = '/home/ubuntu/Experimental_Data/try')
# parser.add_argument('--outfilename', type=str, default = 'try_RVACv1')
# parser.add_argument('--predict_file', type = str, default = '/home/ubuntu/Experimental_Data/v1_5UTR_seqs_with_v1Label.fasta')
# args = parser.parse_args()
# print(args)

# NOTE(review): a `global` statement at module scope has no effect, and
# `device_ids` is never assigned in the code visible here.
global modelfile, layers, heads, embed_dim, batch_toks, inp_len, device_ids, device
# Checkpoint file loaded by predict_file() / predict_raw().
modelfile = 'model.pkl'

# Former logic that parsed the hyper-parameters out of the checkpoint name.
# model_info = modelfile.split('/')[-1].split('_')
# for item in model_info:
#     if 'layers' in item:
#         layers = int(item[0])
#     elif 'heads' in item:
#         heads = int(item[:-5])
#     elif 'embedsize' in item:
#         embed_dim = int(item[:-9])
#     elif 'batchToks' in item:
#         batch_toks = 4096

# Encoder hyper-parameters handed to ESM2_SISS inside CNN_linear.
layers = 6
heads = 16
embed_dim = 128
# Passed as `toks_per_batch` when the batches are built.
batch_toks = 4096

# Only the last `inp_len` characters of every input sequence are kept.
inp_len = 50

# device_ids = list(map(int, args.device_ids.split(',')))
# dist.init_process_group(backend='nccl')
# device = torch.device('cuda:{}'.format(device_ids[args.local_rank]))
# The model is moved to this device before prediction.
device = "cpu"
# torch.cuda.set_device(device)

# `local_rank` and `storage_id` are only referenced by commented-out
# (distributed / CUDA) code in the part of the file visible here.
# local_rank = args.local_rank
local_rank = -1
# storage_id = int(device_ids[local_rank])
storage_id = 0

# repr_layers = [layers]
# NOTE(review): `include` is not read anywhere in the visible code.
include = ["mean"]
93
+
94
class CNN_linear(nn.Module):
    """ESM-2 encoder followed by a small dense head that emits one logit per sequence.

    The encoder size (`layers`, `embed_dim`, `heads`), the `alphabet` and
    `inp_len` are read from module-level globals when the object is built.

    Only `esm2`, `flatten`, `fc`, `relu`, `dropout3` and `output` take part in
    `forward`. The remaining layers (`conv1`, `conv2`, `linear`,
    `direct_output`, `magic_output`, `dropout1`, `dropout2`) are created but
    unused here -- presumably kept so parameter names match the checkpoint;
    TODO confirm.
    """

    def __init__(self,
                 border_mode='same', filter_len=8, nbr_filters=120,
                 dropout1=0, dropout2=0):
        """Create the encoder and all layers.

        Args:
            border_mode: padding argument of both convolutions.
            filter_len: kernel size of both convolutions.
            nbr_filters: number of output channels of both convolutions.
            dropout1: drop probability of the first dropout layer.
            dropout2: drop probability of the second dropout layer.
        """
        super().__init__()

        # Plain hyper-parameter attributes.
        self.embedding_size = embed_dim
        self.border_mode = border_mode
        self.inp_len = inp_len
        self.nodes = 40
        self.cnn_layers = 0
        self.filter_len = filter_len
        self.nbr_filters = nbr_filters

        # Sequence encoder.
        self.esm2 = ESM2_SISS(
            num_layers=layers,
            embed_dim=embed_dim,
            attention_heads=heads,
            alphabet=alphabet,
        )

        # Sub-modules are created in a fixed order so that parameter
        # registration (and random initialisation) stays deterministic.
        self.conv1 = nn.Conv1d(
            in_channels=self.embedding_size,
            out_channels=self.nbr_filters,
            kernel_size=self.filter_len,
            padding=self.border_mode,
        )
        self.conv2 = nn.Conv1d(
            in_channels=self.nbr_filters,
            out_channels=self.nbr_filters,
            kernel_size=self.filter_len,
            padding=self.border_mode,
        )

        self.dropout1 = nn.Dropout(dropout1)
        self.dropout2 = nn.Dropout(dropout2)
        self.dropout3 = nn.Dropout(0.5)
        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(in_features=embed_dim, out_features=self.nodes)
        self.linear = nn.Linear(in_features=self.nbr_filters, out_features=self.nodes)
        self.output = nn.Linear(in_features=self.nodes, out_features=1)
        self.direct_output = nn.Linear(in_features=embed_dim, out_features=1)
        self.magic_output = nn.Linear(in_features=1, out_features=1)

    def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
        """Return the raw (pre-sigmoid) output of the dense head for `tokens`.

        `need_head_weights`, `return_contacts` and `return_representation` are
        accepted but not used by this implementation.
        """
        encoded = self.esm2(tokens, [layers])

        # Keep index 0 along the second axis of the requested layer's
        # representation -- presumably the leading special token; TODO confirm.
        first_position = encoded["representations"][layers][:, 0]

        # Add a trailing axis and flatten it away again before the dense head.
        features = self.flatten(first_position.unsqueeze(2))

        hidden = self.dropout3(self.relu(self.fc(features)))
        return self.output(hidden)
147
+
148
def eval_step(dataloader, model, threshold = 0.5):
    """Run `model` over every batch of `dataloader` and tabulate the predictions.

    Args:
        dataloader: iterable yielding `(ids, strs, toks)` batches.
        model: callable module mapping `toks` to one logit per sequence.
        threshold: probability above which a sequence is labelled 1.

    Returns:
        pandas.DataFrame with the columns 'ID', 'Sequence',
        "Probability as 5'UTR" and "Prediction as 5'UTR", one row per
        sequence.
    """
    model.eval()
    ids_list, strs_list = [], []
    y_prob_list, y_pred_list = [], []

    with torch.no_grad():
        for ids, strs, toks in tqdm(dataloader):
            ids_list.extend(ids)
            strs_list.extend(strs)

            # One logit per sequence -> probability -> hard label.
            logits = model(toks).reshape(-1)
            y_prob = torch.sigmoid(logits)
            y_pred = (y_prob > threshold).long()

            y_prob_list.extend(y_prob.cpu().tolist())
            y_pred_list.extend(y_pred.cpu().tolist())

    # Build the frame column-wise so the numeric columns keep numeric dtypes;
    # a list of rows followed by a transpose turns every column into `object`.
    return pd.DataFrame({
        'ID': ids_list,
        'Sequence': strs_list,
        "Probability as 5'UTR": y_prob_list,
        "Prediction as 5'UTR": y_pred_list,
    })
172
+
173
+
174
+
175
def generate_dataset_dataloader(ids, seqs):
    """Wrap the sequences in a dataset and a token-budgeted DataLoader.

    Args:
        ids: sequence identifiers.
        seqs: sequence strings, parallel to `ids`.

    Returns:
        Tuple `(dataset, dataloader)`.
    """
    fasta_dataset = FastaBatchedDataset(ids, seqs)

    # Batches are limited by the module-level token budget `batch_toks`;
    # extra_toks_per_seq=2 presumably accounts for the special tokens added
    # around each sequence -- TODO confirm.
    batch_indices = fasta_dataset.get_batch_indices(
        toks_per_batch=batch_toks,
        extra_toks_per_seq=2,
    )

    loader = torch.utils.data.DataLoader(
        fasta_dataset,
        collate_fn=alphabet.get_batch_converter(),
        batch_sampler=batch_indices,
        shuffle=False,
    )

    print(f"{len(fasta_dataset)} sequences")
    return fasta_dataset, loader
185
+
186
def read_fasta(file):
    """Read FASTA records and return their IDs and model-ready sequences.

    Each sequence is upper-cased and cut down to its last `inp_len`
    characters. Records whose cut-down sequence contains anything other than
    A, G, C or T are reported on stdout and skipped.

    Args:
        file: path to a FASTA file, or an open file-like object (text or
            binary) such as an in-memory upload buffer.

    Returns:
        Tuple `(ids, sequences)` of two parallel lists.

    Raises:
        SystemExit: with status 1 if the input is empty.
    """
    if hasattr(file, "read"):
        # File-like input: os.path.getsize() only works on paths and the
        # FASTA parser needs text, so read everything and decode bytes here.
        data = file.read()
        if isinstance(data, bytes):
            data = data.decode("utf-8-sig")  # also drops a leading BOM
        is_empty = not data
        source = StringIO(data)
    else:
        is_empty = os.path.getsize(file) == 0
        source = file

    # Reject empty input up front.
    if is_empty:
        print("Error: The file is empty!")
        # Non-zero status so a calling shell can see that the run failed.
        sys.exit(1)

    allowed = set("AGCT")
    ids = []
    sequences = []

    for record in SeqIO.parse(source, "fasta"):
        # Only the last `inp_len` characters are kept, so only those are
        # checked against the allowed alphabet.
        sequence = str(record.seq).upper()[-inp_len:]
        if not set(sequence).issubset(allowed):
            print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
            continue

        # Keep the records that passed validation.
        ids.append(record.id)
        sequences.append(sequence)

    return ids, sequences
212
+
213
def read_raw(raw_input):
    """Parse sequences typed or pasted as text.

    The text may be FASTA (one or more `>header` lines followed by sequence
    lines) or a bare sequence without any header. A bare sequence is wrapped
    in a synthetic record with the ID `input_sequence`, because the FASTA
    parser only yields records that begin with a header line.

    Each sequence is upper-cased and cut down to its last `inp_len`
    characters. Records whose cut-down sequence contains anything other than
    A, G, C or T are reported on stdout and skipped.

    Args:
        raw_input: the text to parse; `None` or blank text yields no records.

    Returns:
        Tuple `(ids, sequences)` of two parallel lists.
    """
    text = raw_input or ""
    if not text.strip():
        return [], []

    # Text that already contains a header line is passed through untouched;
    # only header-less input gets the synthetic header.
    if not any(line.startswith(">") for line in text.splitlines()):
        text = f">input_sequence\n{text.strip()}\n"

    allowed = set("AGCT")
    ids = []
    sequences = []

    for record in SeqIO.parse(StringIO(text), "fasta"):
        # Only the last `inp_len` characters are kept, so only those are
        # checked against the allowed alphabet.
        sequence = str(record.seq).upper()[-inp_len:]
        if not set(sequence).issubset(allowed):
            print(f"Error: The sequence '{record.description}' contains invalid characters. Only A, G, C, T are allowed. Skipping...")
            continue

        # Keep the records that passed validation.
        ids.append(record.id)
        sequences.append(sequence)

    return ids, sequences
235
+
236
+ #######
237
+
238
# Nucleotide alphabet shared by the model (CNN_linear) and by the batch
# converter used in generate_dataset_dataloader().
# alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
alphabet = Alphabet(prepend_toks=("<pad>", "<eos>", "<unk>"), standard_toks = 'AGCT', append_toks=("<cls>", "<mask>", "<sep>"))
# print(alphabet.tok_to_idx)
# assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
# Force the token -> index mapping to these exact values.
# NOTE(review): this replaces only `tok_to_idx`; any index the Alphabet
# constructor derived and cached from its own mapping is left as it was --
# verify those still agree with the table below.
alphabet.tok_to_idx = {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
243
+
244
def predict_file(input_file):
    """Predict for every valid sequence of a FASTA input.

    Args:
        input_file: FASTA input, forwarded unchanged to `read_fasta`.

    Returns:
        The prediction table produced by `eval_step`. The table is also
        printed to stdout.
    """
    print('====Load Data====')
    ids, seqs = read_fasta(input_file)
    _, dataloader = generate_dataset_dataloader(ids, seqs)

    model = CNN_linear().to(device)
    # NOTE: torch.load unpickles the file, so `modelfile` must come from a
    # trusted source.
    checkpoint = torch.load(modelfile, map_location=torch.device('cpu'))
    # Drop any 'module.' fragment from the stored parameter names so they
    # line up with the plain model. strict=False means names that still do
    # not match are ignored without an error.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
    model.load_state_dict(state_dict, strict = False)

    print('====Predict====')
    pred = eval_step(dataloader, model)

    print(pred)
    # Hand the table back so callers (e.g. a UI) can display or save it.
    return pred
261
+
262
def predict_raw(raw_input):
    """Predict for sequences supplied as text.

    Args:
        raw_input: the text to parse, forwarded unchanged to `read_raw`.

    Returns:
        The prediction table produced by `eval_step`. The table is also
        printed to stdout.
    """
    print('====Parse Input====')
    ids, seqs = read_raw(raw_input)
    _, dataloader = generate_dataset_dataloader(ids, seqs)

    model = CNN_linear().to(device)
    # NOTE: torch.load unpickles the file, so `modelfile` must come from a
    # trusted source.
    checkpoint = torch.load(modelfile, map_location=torch.device('cpu'))
    # Drop any 'module.' fragment from the stored parameter names so they
    # line up with the plain model. strict=False means names that still do
    # not match are ignored without an error.
    state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
    model.load_state_dict(state_dict, strict = False)

    print('====Predict====')
    pred = eval_step(dataloader, model)

    print(pred)
    # Hand the table back so callers (e.g. a UI) can display or save it.
    return pred
app.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Streamlit front end: take a sequence typed into the page or an uploaded
# FASTA file, run the predictor and show the resulting table.
import streamlit as st
from Bio import SeqIO
from Predictor import predict_file, predict_raw

st.title("5' UTR prediction")

st.subheader("Input sequence")
seq = st.text_input("Input your sequence here", value="")
st.subheader("Upload sequence file")
uploaded = st.file_uploader("Sequence file in FASTA format")
st.subheader("Prediction result:")
if st.button("Predict"):
    result = None
    if uploaded:
        # The upload is an in-memory bytes buffer rather than a file on
        # disk, so decode it and use the text entry point.
        result = predict_raw(uploaded.getvalue().decode("utf-8"))
    elif seq.strip():
        result = predict_raw(seq)
    else:
        st.warning("Please enter a sequence or upload a FASTA file.")

    # Render the table on the page whenever the predictor returned one.
    if result is not None:
        st.dataframe(result)
model.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:705ea278849702e12285d4059dc15d902cc445f458729415ed83b1bb6515f3d3
3
+ size 4905915
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ biopython
2
+ torch
3
+ numpy
4
+ pandas
5
+ tqdm