import itertools
import os
import random
import sys

import pandas as pd
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

sys.path.append("scripts/")
from foldseek_util import get_struc_seq
from utils import seed_everything
from models import PLTNUM_PreTrainedModel
from datasets_ import PLTNUMDataset


class Config:
    batch_size = 2
    use_amp = False
    num_workers = 1
    max_length = 512
    used_sequence = "left"
    padding_side = "right"
    task = "classification"
    sequence_col = "sequence"
    seed = 42


def predict_stability(model_choice, organism_choice, pdb_file=None, sequence=None, cfg=None):
    """Predict protein stability with a PLTNUM model from a PDB file or a raw sequence."""
    # Build a fresh Config per call instead of sharing a mutable default,
    # since predict() mutates the config in place.
    cfg = cfg if cfg is not None else Config()

    # If a PDB file is provided, extract the sequence with foldseek.
    if pdb_file:
        pdb_path = pdb_file.name  # path of the uploaded PDB file
        os.system("chmod 777 bin/foldseek")  # make the bundled foldseek binary executable
        sequences = get_foldseek_seq(pdb_path)
        if not sequences:
            return "Failed to extract sequence from the PDB file."
        # SaProt uses the combined amino-acid + structure sequence;
        # other architectures use the plain amino-acid sequence.
        sequence = sequences[2] if model_choice == "SaProt" else sequences[0]

    # The released models were trained on HeLa (human) or NIH3T3 (mouse) data.
    cell_line = "HeLa" if organism_choice == "Human" else "NIH3T3"

    if sequence:
        cfg.model = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
        cfg.architecture = model_choice
        cfg.model_path = f"sagawa/PLTNUM-{model_choice}-{cell_line}"
        output = predict(cfg, sequence)
        return f"Predicted stability using {model_choice} for {organism_choice}: {output}"
    return "No valid input provided."


def get_foldseek_seq(pdb_path):
    # Returns the (amino-acid, 3Di structure, combined) sequences for chain A.
    parsed_seqs = get_struc_seq(
        "bin/foldseek",
        pdb_path,
        ["A"],
        process_id=random.randint(0, 10000000),
    )["A"]
    return parsed_seqs


def predict(cfg, sequence):
    # SaProt encodes each residue with two tokens (amino acid + structure).
    cfg.token_length = 2 if cfg.architecture == "SaProt" else 1
    cfg.device = "cuda" if torch.cuda.is_available() else "cpu"
    if cfg.used_sequence == "both":
        cfg.max_length += 1

    seed_everything(cfg.seed)

    df = pd.DataFrame({cfg.sequence_col: [sequence]})
    tokenizer = AutoTokenizer.from_pretrained(
        cfg.model_path, padding_side=cfg.padding_side
    )
    cfg.tokenizer = tokenizer

    dataset = PLTNUMDataset(cfg, df, train=False)
    dataloader = DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=cfg.num_workers,
        pin_memory=True,
        drop_last=False,
    )

    model = PLTNUM_PreTrainedModel.from_pretrained(cfg.model_path, cfg=cfg)
    model.to(cfg.device)
    model.eval()

    # Inference loop: sigmoid for classification, raw outputs for regression.
    predictions = []
    for inputs, _ in dataloader:
        inputs = inputs.to(cfg.device)
        with torch.no_grad():
            with torch.amp.autocast(cfg.device, enabled=cfg.use_amp):
                preds = (
                    torch.sigmoid(model(inputs))
                    if cfg.task == "classification"
                    else model(inputs)
                )
            predictions += preds.cpu().tolist()

    predictions = list(itertools.chain.from_iterable(predictions))
    outputs = {
        "raw prediction values": predictions,
        "binary prediction values": [1 if x > 0.5 else 0 for x in predictions],
    }
    return outputs


# Example run with a short placeholder sequence.
print(predict_stability("SaProt", "Human", sequence="MELKQK"))
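

# Optional sketch (not part of the original script): run the same prediction from
# a local PDB file. predict_stability() expects an object exposing a ".name"
# attribute (as a Gradio file upload does), so a plain path is wrapped in a
# SimpleNamespace here. "example.pdb" is a hypothetical filename; the block is
# skipped if no such file exists.
from types import SimpleNamespace

example_pdb = "example.pdb"  # hypothetical path to a local PDB file
if os.path.exists(example_pdb):
    pdb_result = predict_stability(
        "SaProt",
        "Human",
        pdb_file=SimpleNamespace(name=example_pdb),
    )
    print(pdb_result)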