Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -21,13 +21,13 @@ modelfile = 'model.pt'
|
|
21 |
layers = 6
|
22 |
heads = 16
|
23 |
embed_dim = 128
|
24 |
-
batch_toks =
|
25 |
|
26 |
inp_len = 50
|
27 |
|
28 |
device = "cpu"
|
29 |
|
30 |
-
alphabet = Alphabet(standard_toks = 'AGCT')
|
31 |
assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
32 |
|
33 |
class CNN_linear(nn.Module):
|
@@ -97,7 +97,7 @@ def eval_step(dataloader, model, threshold=0.5):
|
|
97 |
# toks = toks.to(device)
|
98 |
my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
|
99 |
# print(toks)
|
100 |
-
logits = model(toks)
|
101 |
|
102 |
logits = logits.reshape(-1)
|
103 |
# y_prob = torch.sigmoid(logits)
|
@@ -108,8 +108,8 @@ def eval_step(dataloader, model, threshold=0.5):
|
|
108 |
# y_pred_list.extend(y_pred.tolist())
|
109 |
|
110 |
st.success('Done', icon="✅")
|
111 |
-
# data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "
|
112 |
-
data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "
|
113 |
return data_pred
|
114 |
|
115 |
def read_raw(raw_input):
|
@@ -135,7 +135,7 @@ def generate_dataset_dataloader(ids, seqs):
|
|
135 |
dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
|
136 |
|
137 |
# dataset = FastaBatchedDataset(ids, seqs)
|
138 |
-
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=
|
139 |
dataloader = torch.utils.data.DataLoader(dataset,
|
140 |
collate_fn=alphabet.get_batch_converter(),
|
141 |
batch_sampler=batches,
|
@@ -150,8 +150,10 @@ def predict_raw(raw_input):
|
|
150 |
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
151 |
|
152 |
model = CNN_linear()
|
153 |
-
|
154 |
-
|
|
|
|
|
155 |
|
156 |
# st.write('====Predict====')
|
157 |
pred = eval_step(dataloader, model)
|
|
|
21 |
layers = 6
|
22 |
heads = 16
|
23 |
embed_dim = 128
|
24 |
+
batch_toks = 4096
|
25 |
|
26 |
inp_len = 50
|
27 |
|
28 |
device = "cpu"
|
29 |
|
30 |
+
alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
|
31 |
assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
|
32 |
|
33 |
class CNN_linear(nn.Module):
|
|
|
97 |
# toks = toks.to(device)
|
98 |
my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
|
99 |
# print(toks)
|
100 |
+
logits = model(toks, return_representation = True, return_contacts=True)
|
101 |
|
102 |
logits = logits.reshape(-1)
|
103 |
# y_prob = torch.sigmoid(logits)
|
|
|
108 |
# y_pred_list.extend(y_pred.tolist())
|
109 |
|
110 |
st.success('Done', icon="✅")
|
111 |
+
# data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "MRL":logits_list, "prob":y_prob_list, "pred":y_pred_list})
|
112 |
+
data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "MRL":logits_list})
|
113 |
return data_pred
|
114 |
|
115 |
def read_raw(raw_input):
|
|
|
135 |
dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
|
136 |
|
137 |
# dataset = FastaBatchedDataset(ids, seqs)
|
138 |
+
batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=1)
|
139 |
dataloader = torch.utils.data.DataLoader(dataset,
|
140 |
collate_fn=alphabet.get_batch_converter(),
|
141 |
batch_sampler=batches,
|
|
|
150 |
_, dataloader = generate_dataset_dataloader(ids, seqs)
|
151 |
|
152 |
model = CNN_linear()
|
153 |
+
print(model.state_dict().keys())
|
154 |
+
print(torch.load(modelfile, map_location=torch.device('cpu')).keys())
|
155 |
+
model.esm2.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
|
156 |
+
# model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
|
157 |
|
158 |
# st.write('====Predict====')
|
159 |
pred = eval_step(dataloader, model)
|