sagawa commited on
Commit
a0912bb
1 Parent(s): 7cee862

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -30
app.py CHANGED
@@ -28,50 +28,65 @@ class Config:
28
  self.seed = 42
29
 
30
 
31
-
32
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
33
- results = {"file_name": [],
34
- "raw prediction value": [],
35
- "binary prediction value": []
36
- }
 
37
  file_names = []
38
  input_sequences = []
39
 
 
40
  for pdb_file in pdb_files:
41
  pdb_path = pdb_file.name
42
- os.system("chmod 777 bin/foldseek")
43
  sequences = get_foldseek_seq(pdb_path)
 
 
44
  if not sequences:
45
- results["file_name"].append(pdb_file.name.split("/")[-1])
46
  results["raw prediction value"].append(None)
47
  results["binary prediction value"].append(None)
48
  continue
49
 
50
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
51
- file_names.append(pdb_file.name.split("/")[-1])
52
  input_sequences.append(sequence)
53
 
54
- raw_prediction, binary_prediction = predict_stability_core(model_choice, organism_choice, input_sequences, cfg)
55
- results["file_name"] = results["file_name"] + file_names
56
- results["raw prediction value"] = results["raw prediction value"] + raw_prediction
57
- results["binary prediction value"] = results["binary prediction value"] + binary_prediction
58
-
 
 
59
  df = pd.DataFrame(results)
60
  output_csv = "/tmp/predictions.csv"
61
  df.to_csv(output_csv, index=False)
62
 
63
  return output_csv
64
 
65
- def predict_stability_with_sequence(model_choice, organism_choice, sequence, cfg=Config()):
 
 
 
 
 
66
  try:
67
- if not sequence:
68
- return "No valid sequence provided."
69
- raw_prediction, binary_prediction = predict_stability_core(model_choice, organism_choice, [sequence], cfg)
70
- df = pd.DataFrame({"sequence": sequence, "raw prediction value": raw_prediction, "binary prediction value": binary_prediction})
 
 
 
 
 
 
71
  output_csv = "/tmp/predictions.csv"
72
  df.to_csv(output_csv, index=False)
73
 
74
- return output_csv
75
  except Exception as e:
76
  return f"An error occurred: {str(e)}"
77
 
@@ -110,6 +125,7 @@ def predict(cfg, sequences):
110
  cfg.model_path, padding_side=cfg.padding_side
111
  )
112
  cfg.tokenizer = tokenizer
 
113
  dataset = PLTNUMDataset(cfg, df, train=False)
114
  dataloader = DataLoader(
115
  dataset,
@@ -126,19 +142,19 @@ def predict(cfg, sequences):
126
  model.eval()
127
  predictions = []
128
 
129
- for inputs, _ in dataloader:
130
- inputs = inputs.to(cfg.device)
131
- with torch.no_grad():
132
  with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
133
  preds = (
134
  torch.sigmoid(model(inputs))
135
  if cfg.task == "classification"
136
  else model(inputs)
137
  )
138
- predictions += preds.cpu().tolist()
139
 
140
  predictions = list(itertools.chain.from_iterable(predictions))
141
-
142
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
143
 
144
 
@@ -174,9 +190,7 @@ with gr.Blocks() as demo:
174
  gr.Markdown("### Upload your PDB files:")
175
  pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")
176
  predict_button = gr.Button("Predict Stability")
177
- prediction_output = gr.File(
178
- label="Download Predictions"
179
- )
180
 
181
  predict_button.click(
182
  fn=predict_stability_with_pdb,
@@ -192,9 +206,7 @@ with gr.Blocks() as demo:
192
  lines=8,
193
  )
194
  predict_button = gr.Button("Predict Stability")
195
- prediction_output = gr.File(
196
- label="Download Predictions"
197
- )
198
 
199
  predict_button.click(
200
  fn=predict_stability_with_sequence,
 
28
  self.seed = 42
29
 
30
 
 
31
  def predict_stability_with_pdb(model_choice, organism_choice, pdb_files, cfg=Config()):
32
+ results = {
33
+ "file_name": [],
34
+ "raw prediction value": [],
35
+ "binary prediction value": [],
36
+ }
37
  file_names = []
38
  input_sequences = []
39
 
40
+ os.system("chmod 777 bin/foldseek")
41
  for pdb_file in pdb_files:
42
  pdb_path = pdb_file.name
 
43
  sequences = get_foldseek_seq(pdb_path)
44
+
45
+ file_name = os.path.basename(pdb_path)
46
  if not sequences:
47
+ results["file_name"].append(file_name)
48
  results["raw prediction value"].append(None)
49
  results["binary prediction value"].append(None)
50
  continue
51
 
52
  sequence = sequences[2] if model_choice == "SaProt" else sequences[0]
53
+ file_names.append(file_name)
54
  input_sequences.append(sequence)
55
 
56
+ raw_pred, binary_pred = predict_stability_core(
57
+ model_choice, organism_choice, input_sequences, cfg
58
+ )
59
+ results["file_name"].extend(file_names)
60
+ results["raw prediction value"].extend(raw_pred)
61
+ results["binary prediction value"].extend(binary_pred)
62
+
63
  df = pd.DataFrame(results)
64
  output_csv = "/tmp/predictions.csv"
65
  df.to_csv(output_csv, index=False)
66
 
67
  return output_csv
68
 
69
+
70
+ def predict_stability_with_sequence(
71
+ model_choice, organism_choice, sequence, cfg=Config()
72
+ ):
73
+ if not sequence:
74
+ return "No valid sequence provided."
75
  try:
76
+ raw_pred, binary_pred = predict_stability_core(
77
+ model_choice, organism_choice, [sequence], cfg
78
+ )
79
+ df = pd.DataFrame(
80
+ {
81
+ "sequence": sequence,
82
+ "raw prediction value": raw_pred,
83
+ "binary prediction value": binary_pred,
84
+ }
85
+ )
86
  output_csv = "/tmp/predictions.csv"
87
  df.to_csv(output_csv, index=False)
88
 
89
+ return output_csv
90
  except Exception as e:
91
  return f"An error occurred: {str(e)}"
92
 
 
125
  cfg.model_path, padding_side=cfg.padding_side
126
  )
127
  cfg.tokenizer = tokenizer
128
+
129
  dataset = PLTNUMDataset(cfg, df, train=False)
130
  dataloader = DataLoader(
131
  dataset,
 
142
  model.eval()
143
  predictions = []
144
 
145
+ with torch.no_grad():
146
+ for inputs, _ in dataloader:
147
+ inputs = inputs.to(cfg.device)
148
  with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
149
  preds = (
150
  torch.sigmoid(model(inputs))
151
  if cfg.task == "classification"
152
  else model(inputs)
153
  )
154
+ predictions.extend(preds.cpu().tolist())
155
 
156
  predictions = list(itertools.chain.from_iterable(predictions))
157
+
158
  return predictions, [1 if x > 0.5 else 0 for x in predictions]
159
 
160
 
 
190
  gr.Markdown("### Upload your PDB files:")
191
  pdb_files = gr.File(label="Upload PDB Files", file_count="multiple")
192
  predict_button = gr.Button("Predict Stability")
193
+ prediction_output = gr.File(label="Download Predictions")
 
 
194
 
195
  predict_button.click(
196
  fn=predict_stability_with_pdb,
 
206
  lines=8,
207
  )
208
  predict_button = gr.Button("Predict Stability")
209
+ prediction_output = gr.File(label="Download Predictions")
 
 
210
 
211
  predict_button.click(
212
  fn=predict_stability_with_sequence,