# PLTNUM / app.py
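#
# Inference helpers for the PLTNUM protein-stability demo: a pretrained PLTNUM
# classifier is loaded from the Hugging Face Hub and applied to a protein
# sequence that is either supplied directly or extracted from an uploaded PDB
# file with foldseek (the structure-aware input used by SaProt).
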
import sys
import random
import os
import pandas as pd
import torch
import itertools
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Local PLTNUM modules live under scripts/ in this repository.
sys.path.append("scripts/")
from foldseek_util import get_struc_seq
from utils import seed_everything
from models import PLTNUM_PreTrainedModel
from datasets_ import PLTNUMDataset


class Config:
    # Shared inference settings for the prediction helpers below.
    batch_size = 2
    use_amp = False
    num_workers = 1
    max_length = 512
    used_sequence = "left"
    padding_side = "right"
    task = "classification"
    sequence_col = "sequence"
    seed = 42
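
# Pretrained checkpoints are published on the Hugging Face Hub under the naming
# pattern "sagawa/PLTNUM-{architecture}-{cell line}", e.g. "sagawa/PLTNUM-SaProt-HeLa";
# predict_stability() below builds the repo id from the chosen model and organism.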

def predict_stability(
    model_choice, organism_choice, pdb_file=None, sequence=None, cfg=Config()
):
    """Predict protein stability from an uploaded PDB file or a raw sequence."""
    # If a PDB file is given, extract its sequences with foldseek.
    if pdb_file:
        pdb_path = pdb_file.name  # path of the uploaded PDB file
        os.system("chmod 777 bin/foldseek")  # ensure the bundled binary is executable
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # SaProt consumes the structure-aware sequence; other backbones use the
        # plain amino-acid sequence.
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    # Map the organism to the cell line the checkpoint was trained on.
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"

    if sequence:
        cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
        cfg.architecture = model_choice
        cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
        output = predict(cfg, sequence)
        return f"Predicted Stability using {model_choice} for {organism_choice}: {output}"
    return "No valid input provided."


def get_foldseek_seq(pdb_path):
    # get_struc_seq returns, per chain, the amino-acid sequence, the 3Di
    # structure sequence, and the combined structure-aware sequence (index 2,
    # used by SaProt). Only chain A is parsed here.
    parsed_seqs = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        process_id=random.randint(0, 10000000),
    )["A"]
    return parsed_seqs


def predict(cfg, sequence):
    """Run a single-sequence prediction with a pretrained PLTNUM model."""
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if cfg.used_sequence == "both":
        cfg.max_length += 1

    seed_everything(cfg.seed)

    df = pd.DataFrame({cfg.sequence_col: [sequence]})
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    cfg.tokenizer = tokenizer

    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    predictions = []
    for inputs, _ in dataloader:
        inputs = inputs.to(cfg.device)
        with torch.no_grad():
            with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                preds = (
                    torch.sigmoid(model(inputs))
                    if cfg.task == "classification"
                    else model(inputs)
                )
        predictions += preds.cpu().tolist()

    # Flatten per-batch outputs and report raw scores plus 0/1 labels.
    predictions = list(itertools.chain.from_iterable(predictions))
    outputs = {}
    outputs["raw prediction values"] = predictions
    outputs["binary prediction values"] = [1 if x > 0.5 else 0 for x in predictions]
    return outputs
predict_stability("SaProt", "Human", sequence="MELKQK")