yanyichu commited on
Commit
68d40a6
1 Parent(s): 8e2cf76

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -4
app.py CHANGED
@@ -66,8 +66,6 @@ class CNN_linear(nn.Module):
66
  self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
67
  self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
68
  self.output = nn.Linear(in_features = self.nodes, out_features = 1)
69
- self.direct_output = nn.Linear(in_features = embed_dim, out_features = 1)
70
- self.magic_output = nn.Linear(in_features = 1, out_features = 1)
71
 
72
  def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
73
 
@@ -151,8 +149,8 @@ def predict_raw(raw_input):
151
 
152
  model = CNN_linear()
153
  st.write(model.state_dict().keys())
154
- st.write(torch.load(modelfile, map_location=torch.device('cpu')).keys())
155
- model.esm2.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = False)
156
  # model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
157
 
158
  # st.write('====Predict====')
 
66
  self.fc = nn.Linear(in_features = embed_dim, out_features = self.nodes)
67
  self.linear = nn.Linear(in_features = self.nbr_filters, out_features = self.nodes)
68
  self.output = nn.Linear(in_features = self.nodes, out_features = 1)
 
 
69
 
70
  def forward(self, tokens, need_head_weights=True, return_contacts=False, return_representation=True):
71
 
 
149
 
150
  model = CNN_linear()
151
  st.write(model.state_dict().keys())
152
+ st.write({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}.keys())
153
+ model.load_state_dict({k.replace('module.', ''):v for k,v in torch.load(modelfile, map_location=torch.device('cpu')).items()}, strict = True)
154
  # model.load_state_dict(torch.load(modelfile, map_location=torch.device('cpu')), strict = False)
155
 
156
  # st.write('====Predict====')