yanyichu committed on
Commit
711fba0
1 Parent(s): b4a1072

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -8
app.py CHANGED
@@ -21,13 +21,13 @@ modelfile = 'model.pt'
21
  layers = 6
22
  heads = 16
23
  embed_dim = 128
24
- batch_toks = 1024
25
 
26
  inp_len = 50
27
 
28
  device = "cpu"
29
 
30
- alphabet = Alphabet(standard_toks = 'AGCT')
31
  assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
32
 
33
  class CNN_linear(nn.Module):
@@ -97,7 +97,7 @@ def eval_step(dataloader, model, threshold=0.5):
97
  # toks = toks.to(device)
98
  my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
99
  # print(toks)
100
- logits = model(toks)
101
 
102
  logits = logits.reshape(-1)
103
  # y_prob = torch.sigmoid(logits)
@@ -108,8 +108,8 @@ def eval_step(dataloader, model, threshold=0.5):
108
  # y_pred_list.extend(y_pred.tolist())
109
 
110
  st.success('Done', icon="✅")
111
- # data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list, "prob":y_prob_list, "pred":y_pred_list})
112
- data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "Translation Efficiency":logits_list})
113
  return data_pred
114
 
115
  def read_raw(raw_input):
@@ -135,7 +135,7 @@ def generate_dataset_dataloader(ids, seqs):
135
  dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
136
 
137
  # dataset = FastaBatchedDataset(ids, seqs)
138
- batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=2)
139
  dataloader = torch.utils.data.DataLoader(dataset,
140
  collate_fn=alphabet.get_batch_converter(),
141
  batch_sampler=batches,
@@ -150,8 +150,10 @@ def predict_raw(raw_input):
150
  _, dataloader = generate_dataset_dataloader(ids, seqs)
151
 
152
  model = CNN_linear()
153
-
154
- model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
 
 
155
 
156
  # st.write('====Predict====')
157
  pred = eval_step(dataloader, model)
 
21
  layers = 6
22
  heads = 16
23
  embed_dim = 128
24
+ batch_toks = 4096
25
 
26
  inp_len = 50
27
 
28
  device = "cpu"
29
 
30
+ alphabet = Alphabet(mask_prob = 0.0, standard_toks = 'AGCT')
31
  assert alphabet.tok_to_idx == {'<pad>': 0, '<eos>': 1, '<unk>': 2, 'A': 3, 'G': 4, 'C': 5, 'T': 6, '<cls>': 7, '<mask>': 8, '<sep>': 9}
32
 
33
  class CNN_linear(nn.Module):
 
97
  # toks = toks.to(device)
98
  my_bar.progress((i+1)/len(dataloader), text="Running UTR_LM")
99
  # print(toks)
100
+ logits = model(toks, return_representation = True, return_contacts=True)
101
 
102
  logits = logits.reshape(-1)
103
  # y_prob = torch.sigmoid(logits)
 
108
  # y_pred_list.extend(y_pred.tolist())
109
 
110
  st.success('Done', icon="✅")
111
+ # data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "MRL":logits_list, "prob":y_prob_list, "pred":y_pred_list})
112
+ data_pred = pd.DataFrame({'ID':ids_list, 'Sequence':strs_list, "MRL":logits_list})
113
  return data_pred
114
 
115
  def read_raw(raw_input):
 
135
  dataset = FastaBatchedDataset(ids, seqs, mask_prob = 0.0)
136
 
137
  # dataset = FastaBatchedDataset(ids, seqs)
138
+ batches = dataset.get_batch_indices(toks_per_batch=batch_toks, extra_toks_per_seq=1)
139
  dataloader = torch.utils.data.DataLoader(dataset,
140
  collate_fn=alphabet.get_batch_converter(),
141
  batch_sampler=batches,
 
150
  _, dataloader = generate_dataset_dataloader(ids, seqs)
151
 
152
  model = CNN_linear()
153
+ print(model.state_dict().keys())
154
+ print(torch.load(modelfile, map_location=torch.device('cpu')).keys())
155
+ model.esm2.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
156
+ # model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
157
 
158
  # st.write('====Predict====')
159
  pred = eval_step(dataloader, model)