# oucgc1996's picture
# Update app.py
# 777de14 verified
import VolumeMaker
import utils
import numpy as np
import random
import torch
import torch.nn as nn
import pathlib
import pandas as pd
import shutil
import subprocess
from transformers import AutoModelForSequenceClassification
from torch.utils.data import Dataset,DataLoader
import pandas as pd
device = torch.device("cpu")
import os
join=os.path.join
from transformers import AutoTokenizer
import torch.nn.functional as F
from rdkit import Chem
from rdkit.Chem import AllChem
from collections import OrderedDict
from tqdm import tqdm
import time
import gradio as gr
# Hugging Face checkpoint of the ESM-2 protein language model (tokenizer + encoder).
model_checkpoint = "facebook/esm2_t6_8M_UR50D"
# Working directories resolved next to this script: predicted PDB files are
# written under ``structure/`` and per-structure scratch folders under ``temp/``.
pdb_path = pathlib.Path(__file__).parent / "structure"
temp_path = pathlib.Path(__file__).parent / "temp"
def setup_seed(seed):
    """Seed every random number generator used by the app with ``seed``.

    Seeds PyTorch (CPU, then all CUDA devices), NumPy and Python's ``random``
    module, and asks cuDNN for deterministic algorithms so runs are reproducible.
    """
    seeders = (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
    torch.backends.cudnn.deterministic = True
# Seed all RNGs at import time, before any model is constructed, so layer
# initialisation in MyModel() (built later in ACE) is reproducible.
setup_seed(4)
# Batch size of the inference DataLoader built in ACE.
batch_size = 1
# Number of output classes of MyModel (ACE maps 0 -> "low", 1 -> "high").
num_labels = 2
# Morgan fingerprint radius and bit length (defaults of peptides_to_fingerprint_matrix).
radius = 2
n_features = 1024
# Feature width shared by the attention blocks and the projected feature branches.
hid_dim = 300
n_heads = 1
# NOTE(review): not referenced by the visible code; CrossAttentionBlock passes
# dropout=0.1 explicitly and MyModel uses nn.Dropout(0).
dropout = 0
# ESM-2 tokenizer, loaded once at import and reused by collate_fn.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
class MyDataset(Dataset):
    """Map-style dataset pairing each peptide sequence with its structure feature.

    ``dict_data`` must provide ``'text'`` (indexable sequences) and
    ``'structure'`` (structure identifiers). The structure features are
    computed eagerly in the constructor by ``pdb_structure``.
    """

    def __init__(self, dict_data) -> None:
        super().__init__()
        self.data = dict_data
        self.structure = pdb_structure(dict_data['structure'])

    def __len__(self):
        return len(self.data['text'])

    def __getitem__(self, index):
        sequences = self.data['text']
        return sequences[index], self.structure[index]
def collate_fn(batch):
    """Collate ``(sequence, structure)`` samples into model inputs.

    Returns a 3-tuple, every tensor moved to the module-level ``device``:
      * dict with the tokenizer's ``input_ids`` and ``attention_mask``,
      * the batch's structure features rebuilt into one tensor,
      * the Morgan fingerprint matrix of the sequences as a float64 tensor.
    """
    sequences = [sample[0] for sample in batch]
    # Round-trip through Python lists, exactly as before (this also fixes the dtype).
    structure = torch.tensor([sample[1].tolist() for sample in batch]).to(device)
    # Longest sequence + 2, presumably room for the tokenizer's two special
    # tokens so truncation does not cut residues — TODO confirm.
    longest = max(len(sample[0]) for sample in batch) + 2
    fingerprint_matrix = peptides_to_fingerprint_matrix(sequences, radius, n_features)
    fingerprint = torch.tensor(fingerprint_matrix, dtype=float).to(device)
    encoded = tokenizer(sequences, padding=True, truncation=True, max_length=longest, return_tensors='pt')
    model_inputs = {
        'input_ids': encoded['input_ids'].to(device),
        'attention_mask': encoded['attention_mask'].to(device),
    }
    return model_inputs, structure, fingerprint
class AttentionBlock(nn.Module):
    """Scaled dot-product attention across the feature values of flat vectors.

    ``query``, ``key`` and ``value`` are reshaped with
    ``view(batch, n_heads, hid_dim // n_heads)``, so each must hold
    ``batch * hid_dim`` elements (in practice ``(batch, hid_dim)``). Each head
    builds a ``(d, d)`` attention map between its individual feature values
    (``d = hid_dim // n_heads``). The output has shape ``(batch, hid_dim)``.

    Args:
        hid_dim: feature width; must be divisible by ``n_heads``.
        n_heads: number of heads.
        dropout: dropout probability applied to the attention map and the output.
    """

    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads
        assert hid_dim % n_heads == 0
        # Construction order of the Linear layers is kept: it determines how the
        # seeded RNG initialises them.
        self.f_q = nn.Linear(hid_dim, hid_dim)
        self.f_k = nn.Linear(hid_dim, hid_dim)
        self.f_v = nn.Linear(hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
        # Non-persistent buffer: it follows ``module.to(...)`` like the weights
        # (instead of being pinned to a global device) but is not stored in /
        # expected from the state dict, so existing checkpoints load unchanged.
        self.register_buffer(
            "scale",
            torch.sqrt(torch.tensor([float(hid_dim // n_heads)])),
            persistent=False,
        )

    def forward(self, query, key, value, mask=None):
        """Attend ``query`` over ``key``/``value``; returns ``(batch, hid_dim)``.

        ``mask`` (optional) must broadcast to ``(batch, n_heads, d, d)``;
        positions where it is 0 are excluded from the softmax.
        """
        batch_size = query.shape[0]
        head_dim = self.hid_dim // self.n_heads
        # Q, V: (batch, heads, d, 1); K_T: (batch, heads, 1, d). Q @ K_T is an
        # outer product, giving a (d, d) energy map per head.
        Q = self.f_q(query).view(batch_size, self.n_heads, head_dim).unsqueeze(3)
        K_T = self.f_k(key).view(batch_size, self.n_heads, head_dim).unsqueeze(3).transpose(2, 3)
        V = self.f_v(value).view(batch_size, self.n_heads, head_dim).unsqueeze(3)
        energy = torch.matmul(Q, K_T) / self.scale
        if mask is not None:
            energy = energy.masked_fill(mask == 0, -1e10)
        attention = self.do(F.softmax(energy, dim=-1))
        weighted = torch.matmul(attention, V)  # (batch, heads, d, 1)
        weighted = weighted.permute(0, 2, 1, 3).contiguous()
        weighted = weighted.view(batch_size, self.n_heads * head_dim)
        return self.do(self.fc(weighted))
class CrossAttentionBlock(nn.Module):
    """Fuse structure, fingerprint and sequence features with ONE shared AttentionBlock.

    The same ``self.att`` module, and therefore the same weights, is applied
    three times:
      1. the fingerprint attends to the structure (with a residual connection),
      2. the result attends to itself,
      3. the result attends to the sequence features, which gives the output.
    MyModel passes ``(batch, hid_dim)`` vectors for all three inputs.
    """

    def __init__(self):
        super().__init__()
        self.att = AttentionBlock(hid_dim=hid_dim, n_heads=n_heads, dropout=0.1)

    def forward(self, structure_feature, fingerprint_feature, sequence_feature):
        attend = self.att
        # Cross-attention: enrich the fingerprint with structure information (residual).
        enriched = fingerprint_feature + attend(fingerprint_feature, structure_feature, structure_feature)
        # Self-attention over the enriched fingerprint.
        refined = attend(enriched, enriched, enriched)
        # Cross-attention with the sequence features.
        return attend(refined, sequence_feature, sequence_feature)
def peptides_to_fingerprint_matrix(peptides, radius=radius, n_features=n_features):
    """Encode peptide sequences as Morgan fingerprint bit vectors.

    Each sequence is turned into an RDKit molecule with ``Chem.MolFromSequence``
    and hashed into an ``n_features``-bit Morgan fingerprint of ``radius``.

    Args:
        peptides: iterable of peptide sequences (one-letter amino-acid codes).
        radius: Morgan fingerprint radius.
        n_features: fingerprint length in bits.

    Returns:
        float64 ``numpy.ndarray`` of shape ``(len(peptides), n_features)``.

    Raises:
        ValueError: if RDKit cannot parse a sequence.
    """
    peptides = list(peptides)
    features = np.zeros((len(peptides), n_features))
    for i, peptide in enumerate(peptides):
        mol = Chem.MolFromSequence(peptide)
        if mol is None:
            # MolFromSequence signals unparsable input by returning None; without
            # this check the fingerprint call below fails with an opaque error.
            raise ValueError(f"Could not parse peptide sequence {peptide!r}")
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_features)
        fp_array = np.zeros((1,))  # resized in place by ConvertToNumpyArray
        AllChem.DataStructs.ConvertToNumpyArray(fp, fp_array)
        features[i, :] = fp_array
    return features
class MyModel(nn.Module):
    """Peptide classifier fusing ESM-2 sequence, structure and fingerprint features.

    Branches (see ``forward``):
      * sequence: ESM-2 with a ``hid_dim``-output classification head, evaluated
        under ``torch.no_grad()``; its logits are used as the sequence feature;
      * fingerprint: fingerprint vector -> 2-layer BiLSTM -> linear to ``hid_dim``;
      * structure: point-cloud tensor -> 2-layer BiLSTM -> linear to ``hid_dim``.
    The three ``hid_dim`` vectors are fused by ``CrossAttentionBlock`` and fed to
    an MLP (300 -> 256 -> 128 -> 64 -> ``num_labels``) with batch norm + ReLU.

    NOTE: attribute names and construction order matter. ACE loads
    ``best_model.pth`` with ``strict=False`` after ``setup_seed(4)``, so any
    parameter missing from the checkpoint keeps its seeded random initialisation.
    """
    def __init__(self):
        super().__init__()
        # ESM-2 with a classification head of hid_dim outputs (logits reused as features).
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_checkpoint,num_labels=hid_dim)
        # Batch norms for the 256/128/64-wide MLP layers below.
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        # MLP head on the fused feature; the 300 input width must equal hid_dim.
        self.fc1 = nn.Linear(300,256)
        self.fc2 = nn.Linear(256,128)
        self.fc3 = nn.Linear(128,64)
        # Projections of the flattened BiLSTM outputs to hid_dim:
        # 1024 = 1 step x 2 directions x 512, 1500 = 3 steps x 2 directions x 250
        # (assuming the step counts described in forward — TODO confirm).
        self.fc_fingerprint = nn.Linear(1024,hid_dim)
        self.fc_structure = nn.Linear(1500,hid_dim)
        # hidden_size is half the input width, so the bidirectional output keeps that width.
        self.fingerprint_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=1024, hidden_size=1024//2, batch_first=True)
        self.structure_lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=500, hidden_size=500//2, batch_first=True)
        self.output_layer = nn.Linear(64,num_labels)
        # p=0: every self.dropout call below is an identity.
        self.dropout = nn.Dropout(0)
        self.CAB = CrossAttentionBlock()
    def forward(self,structure, x, fingerprint):
        """Return class probabilities of shape ``(batch, num_labels)``.

        Args:
            structure: point-cloud tensor, presumably ``(batch, 500, 3)`` — TODO
                confirm (after the permute the structure LSTM needs 500 features
                per step).
            x: tokenizer output dict with ``input_ids`` and ``attention_mask``.
            fingerprint: fingerprint matrix, presumably ``(batch, 1024)``.
        """
        # Add a trailing axis and cast to float32 (collate_fn yields float64).
        fingerprint = torch.unsqueeze(fingerprint, 2).float()
        # Swap the last two axes so each LSTM step sees a full feature row:
        # structure -> (batch, 3, points), fingerprint -> (batch, 1, 1024).
        structure = structure.permute(0, 2, 1)
        fingerprint = fingerprint.permute(0, 2, 1)
        # ESM-2 runs without gradient tracking, so it is never updated through this path.
        with torch.no_grad():
            bert_output = self.bert(input_ids=x['input_ids'].to(device),attention_mask=x['attention_mask'].to(device))
        sequence_feature = self.dropout(bert_output["logits"])
        structure = structure.to(device)
        fingerprint_feature, _ = self.fingerprint_lstm(fingerprint)
        structure_feature, _ = self.structure_lstm(structure)
        # Flatten the per-step LSTM outputs into one vector per sample.
        fingerprint_feature = fingerprint_feature.flatten(start_dim=1)
        structure_feature = structure_feature.flatten(start_dim=1)
        fingerprint_feature = self.fc_fingerprint(fingerprint_feature)
        structure_feature = self.fc_structure(structure_feature)
        output_feature = self.CAB(structure_feature, fingerprint_feature, sequence_feature)
        output_feature = self.dropout(self.relu(self.bn1(self.fc1(output_feature))))
        output_feature = self.dropout(self.relu(self.bn2(self.fc2(output_feature))))
        output_feature = self.dropout(self.relu(self.bn3(self.fc3(output_feature))))
        output_feature = self.dropout(self.output_layer(output_feature))
        # Softmax over classes: the model returns probabilities, not logits.
        return torch.softmax(output_feature,dim=1)
def pdb_structure(Structure_index):
    """Build surface point-cloud features for a list of structure identifiers.

    For every identifier ``index``, ``<pdb_path>/<index>.pdb`` is copied into a
    scratch folder ``<temp_path>/<index>``, parsed with ``utils.parsePDB`` and
    converted by ``VolumeMaker.PointCloudSurface``.

    Args:
        Structure_index: iterable of structure identifiers (PDB file stems).

    Returns:
        CPU tensor of the stacked per-structure point clouds with dimension 1
        squeezed. Each cloud is reshaped with ``view(pdb_num, -1, 3)``;
        presumably ``pdb_num == 1`` per folder — TODO confirm.

    The scratch folders are removed whether or not processing succeeds.
    """
    created_folders = []
    point_clouds = []
    try:
        for index in Structure_index:
            structure_folder = join(temp_path, str(index))
            os.makedirs(structure_folder, exist_ok=True)
            created_folders.append(structure_folder)
            pdb_file = join(pdb_path, f"{index}.pdb")
            if os.path.exists(pdb_file):
                shutil.copy2(pdb_file, structure_folder)
            else:
                print(f"PDB file not found for structure {index}")
            coords, atname, pdbname, pdb_num = utils.parsePDB(structure_folder)
            # NOTE(review): atoms_channel is computed and moved to device but never used below.
            atoms_channel = utils.atomlistToChannels(atname)
            # Presumably per-atom radii; this local shadows the module-level fingerprint `radius`.
            radius = utils.atomlistToRadius(atname)
            PointCloudSurfaceObject = VolumeMaker.PointCloudSurface(device=device)
            coords = coords.to(device)
            radius = radius.to(device)
            atoms_channel = atoms_channel.to(device)
            SurfacePoitCloud = PointCloudSurfaceObject(coords, radius)
            point_clouds.append(SurfacePoitCloud.view(pdb_num, -1, 3).cpu())
        return torch.squeeze(torch.stack(point_clouds), dim=1)
    finally:
        # Clean up on success AND failure (previously an exception leaked the
        # folders). ignore_errors also tolerates an identifier listed twice,
        # whose folder was already removed by the first pass.
        for folder in created_folders:
            shutil.rmtree(folder, ignore_errors=True)
def ACE(file):
    """Classify one peptide sequence as ``"low"`` or ``"high"`` (Gradio callback).

    Pipeline: fold the sequence with the ESM Atlas ``foldSequence`` API (via
    ``curl``), build structure/fingerprint/sequence features, load
    ``best_model.pth`` into ``MyModel`` and predict.

    Args:
        file: the peptide sequence from the Gradio textbox (despite its name,
            not a file). Leading/trailing whitespace is ignored.

    Returns:
        ``(label, probability)``: ``label`` is ``"low"`` (class 0) or ``"high"``
        (class 1), ``probability`` is the softmax score of that class.

    Raises:
        ValueError: if the sequence is empty.
        RuntimeError: if the structure-prediction request fails.
    """
    sequence = file.strip()
    if not sequence:
        raise ValueError("Please enter a peptide sequence.")
    # Start from an empty structure directory so only this request's PDB is present.
    if os.path.exists(pdb_path):
        shutil.rmtree(pdb_path)
    os.makedirs(pdb_path)
    test_sequences = [sequence]
    test_Structure_index = [f"structure_{i}" for i in range(len(test_sequences))]
    test_dict = {"text": test_sequences, 'structure': test_Structure_index}
    print("=================================Structure prediction========================")
    for seq, structure_id in tqdm(zip(test_sequences, test_Structure_index), total=len(test_sequences)):
        # --data-raw: unlike --data, curl never treats a leading "@" in user input
        # as "read and send this local file".
        # --fail: HTTP errors give a non-zero exit code instead of an error page
        # that would otherwise be saved as a PDB file.
        # NOTE(review): -k disables TLS certificate verification for this API.
        command = ["curl", "-X", "POST", "-k", "--fail", "--data-raw", seq,
                   "https://api.esmatlas.com/foldSequence/v1/pdb/"]
        result = subprocess.run(command, capture_output=True, text=True)
        if result.returncode != 0 or not result.stdout.strip():
            raise RuntimeError(
                f"Structure prediction failed for {structure_id}: {result.stderr.strip()}"
            )
        with open(os.path.join(pdb_path, f"{structure_id}.pdb"), 'w') as pdb_out:
            pdb_out.write(result.stdout)
    test_data = MyDataset(test_dict)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, collate_fn=collate_fn, shuffle=False)
    # Load the model. strict=False: checkpoint/model key mismatches are skipped silently.
    model = MyModel()
    weights_file = pathlib.Path(__file__).parent / "best_model.pth"
    model.load_state_dict(torch.load(weights_file, map_location=torch.device('cpu')), strict=False)
    model = model.to(device)
    model.eval()
    out_text, out_prob = None, None
    labels = ("low", "high")  # class index -> label
    with torch.no_grad():
        print("=================================Start prediction========================")
        for batch, structure_fea, fingerprint in test_dataloader:
            probabilities = model(structure_fea, batch, fingerprint).cpu().numpy()
            # Read each sample's probability from its own row.
            for row in probabilities:
                predicted = int(np.argmax(row))
                out_text = labels[predicted]
                out_prob = float(row[predicted])
    return out_text, out_prob
# Markdown description shown in the app. Resolved next to this script (not the
# current working directory) and decoded as UTF-8 regardless of platform default.
with open(pathlib.Path(__file__).parent / "ACE.md", "r", encoding="utf-8") as f:
    description = f.read()
iface = gr.Interface(fn=ACE,
                     title="🏹DeepACE",
                     inputs=gr.Textbox(show_label=False, placeholder="Enter peptide only", lines=4),
                     outputs=[gr.Textbox(show_label=False, placeholder="class", lines=1),
                              gr.Textbox(show_label=False, placeholder="probability", lines=1)],
                     description=description)
iface.launch()