sagawa committed on
Commit
e4d81ca
1 Parent(s): 1e1820d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -84
app.py CHANGED
@@ -1,9 +1,9 @@
1
- import gradio as gr
2
  import sys
3
  import random
4
  import os
5
  import pandas as pd
6
  import torch
 
7
  from torch.utils.data import DataLoader
8
  from transformers import AutoTokenizer
9
 
@@ -23,6 +23,7 @@ class Config:
23
  padding_side = "right"
24
  task = "classification"
25
  sequence_col = "sequence"
 
26
 
27
 
28
  # Assuming 'predict_stability' is your function that predicts protein stability
@@ -91,13 +92,14 @@ def predict(cfg, sequence):
91
  model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
92
  model.to(cfg.device)
93
 
 
94
  model.eval()
95
  predictions = []
96
 
97
  for inputs, _ in dataloader:
98
  inputs = inputs.to(cfg.device)
99
  with torch.no_grad():
100
- with torch.amp.autocast(enabled=cfg.use_amp):
101
  preds = (
102
  torch.sigmoid(model(inputs))
103
  if cfg.task == "classification"
@@ -105,88 +107,9 @@ def predict(cfg, sequence):
105
  )
106
  predictions += preds.cpu().tolist()
107
  outputs = {}
 
108
  outputs["raw prediction values"] = predictions
109
  outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
110
- return str(outputs)
111
 
112
-
113
- # Gradio Interface
114
- with gr.Blocks() as demo:
115
- gr.Markdown(
116
- """
117
- # PLTNUM: Protein LifeTime Neural Model
118
- **Predict the protein half-life from its sequence or PDB file.**
119
- """
120
- )
121
-
122
- gr.Image(
123
- "https://github.com/sagawatatsuya/PLTNUM/blob/main/model-image.png?raw=true",
124
- label="Model Image",
125
- )
126
-
127
- # Model and Organism selection in the same row to avoid layout issues
128
- with gr.Row():
129
- model_choice = gr.Radio(
130
- choices=["SaProt", "ESM2"],
131
- label="Select PLTNUM's base model.",
132
- value="SaProt",
133
- )
134
- organism_choice = gr.Radio(
135
- choices=["Mouse", "Human"],
136
- label="Select the target organism.",
137
- value="Mouse",
138
- )
139
-
140
- with gr.Tabs():
141
- with gr.TabItem("Upload PDB File"):
142
- gr.Markdown("### Upload your PDB file:")
143
- pdb_file = gr.File(label="Upload PDB File")
144
-
145
- predict_button = gr.Button("Predict Stability")
146
- prediction_output = gr.Textbox(
147
- label="Stability Prediction", interactive=False
148
- )
149
-
150
- predict_button.click(
151
- fn=predict_stability,
152
- inputs=[model_choice, organism_choice, pdb_file],
153
- outputs=prediction_output,
154
- )
155
-
156
- with gr.TabItem("Enter Protein Sequence"):
157
- gr.Markdown("### Enter the protein sequence:")
158
- sequence = gr.Textbox(
159
- label="Protein Sequence",
160
- placeholder="Enter your protein sequence here...",
161
- lines=8,
162
- )
163
- predict_button = gr.Button("Predict Stability")
164
- prediction_output = gr.Textbox(
165
- label="Stability Prediction", interactive=False
166
- )
167
-
168
- predict_button.click(
169
- fn=predict_stability,
170
- inputs=[model_choice, organism_choice, sequence],
171
- outputs=prediction_output,
172
- )
173
-
174
- gr.Markdown(
175
- """
176
- ### How to Use:
177
- - **Select Model**: Choose between 'SaProt' or 'ESM2' for your prediction.
178
- - **Select Organism**: Choose between 'Mouse' or 'Human'.
179
- - **Upload PDB File**: Choose the 'Upload PDB File' tab and upload your file.
180
- - **Enter Sequence**: Alternatively, switch to the 'Enter Protein Sequence' tab and input your sequence.
181
- - **Predict**: Click 'Predict Stability' to receive the prediction.
182
- """
183
- )
184
-
185
- gr.Markdown(
186
- """
187
- ### About the Tool
188
- This tool allows researchers and scientists to predict the stability of proteins using advanced algorithms. It supports both PDB file uploads and direct sequence input.
189
- """
190
- )
191
-
192
- demo.launch()
 
 
1
  import sys
2
  import random
3
  import os
4
  import pandas as pd
5
  import torch
6
+ import itertools
7
  from torch.utils.data import DataLoader
8
  from transformers import AutoTokenizer
9
 
 
23
  padding_side = "right"
24
  task = "classification"
25
  sequence_col = "sequence"
26
+ seed = 42
27
 
28
 
29
  # Assuming 'predict_stability' is your function that predicts protein stability
 
92
  model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
93
  model.to(cfg.device)
94
 
95
+ # predictions = predict_fn(loader, model, cfg)
96
  model.eval()
97
  predictions = []
98
 
99
  for inputs, _ in dataloader:
100
  inputs = inputs.to(cfg.device)
101
  with torch.no_grad():
102
+ with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
103
  preds = (
104
  torch.sigmoid(model(inputs))
105
  if cfg.task == "classification"
 
107
  )
108
  predictions += preds.cpu().tolist()
109
  outputs = {}
110
+ predictions = list(itertools.chain.from_iterable(predictions))
111
  outputs["raw prediction values"] = predictions
112
  outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
113
+ return outputs
114
 
115
+ predict_stability("SaProt", "Human", sequence="MELKQK")