# Source: UTR_LM/esm/model/esm2_secondarystructure.py
# Last commit: "minor fixes" (5aa3fcd) by Shawn Shen
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import torch
import torch.nn as nn
import esm
from esm.modules import ContactPredictionHead, ESM1bLayerNorm, RobertaLMHead, TransformerLayer
# This module defines ESM2, a PyTorch model that subclasses nn.Module.
# __init__ stores hyperparameters such as num_layers, embed_dim and
# attention_heads, and builds the submodules: the token embedding
# (embed_tokens), a stack of Transformer layers (layers), the contact
# prediction head (contact_head), and the heads lm_head, supervised_linear
# and structure_linear. forward takes a batch of token indices (tokens) and
# returns the predictions together with auxiliary outputs.
class ESM2(nn.Module):
    """ESM-2 style Transformer encoder with supervised and structure heads.

    On top of the token embedding, the stack of ``TransformerLayer`` modules
    and the final layer norm, the model exposes four outputs:

    * ``lm_head``: per-token logits over the alphabet (weights tied to the
      token embedding),
    * ``supervised_linear``: one scalar per sequence, computed from the
      first token position,
    * ``structure_linear``: three logits per token,
    * ``contact_head``: called on the stacked attention maps; it is expected
      to return a ``(attentions_symm, contacts)`` pair.

    Args:
        num_layers: Number of Transformer layers.
        embed_dim: Embedding / hidden size.
        attention_heads: Number of attention heads per layer.
        alphabet: An ``esm.data.Alphabet`` instance, or an architecture name
            that ``esm.data.Alphabet.from_architecture`` understands.
        token_dropout: If True, embeddings of mask tokens are zeroed and the
            remaining embeddings are rescaled in ``forward``.
    """

    def __init__(
        self,
        num_layers: int = 33,
        embed_dim: int = 1280,
        attention_heads: int = 20,
        # String annotation: it is not evaluated when the class is defined.
        alphabet: "Union[esm.data.Alphabet, str]" = "ESM-1b",
        token_dropout: bool = True,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.embed_dim = embed_dim
        self.attention_heads = attention_heads

        if not isinstance(alphabet, esm.data.Alphabet):
            alphabet = esm.data.Alphabet.from_architecture(alphabet)
        self.alphabet = alphabet
        self.alphabet_size = len(alphabet)
        self.padding_idx = alphabet.padding_idx
        self.mask_idx = alphabet.mask_idx
        self.cls_idx = alphabet.cls_idx
        self.eos_idx = alphabet.eos_idx
        self.prepend_bos = alphabet.prepend_bos
        self.append_eos = alphabet.append_eos
        self.token_dropout = token_dropout

        self._init_submodules()

    def _init_submodules(self):
        """Create the embedding, the Transformer stack and all output heads."""
        self.embed_scale = 1
        self.embed_tokens = nn.Embedding(
            self.alphabet_size,
            self.embed_dim,
            padding_idx=self.padding_idx,
        )

        self.layers = nn.ModuleList(
            [
                TransformerLayer(
                    self.embed_dim,
                    4 * self.embed_dim,
                    self.attention_heads,
                    add_bias_kv=False,
                    use_esm1b_layer_norm=True,
                    use_rotary_embeddings=True,
                )
                for _ in range(self.num_layers)
            ]
        )

        self.contact_head = ContactPredictionHead(
            self.num_layers * self.attention_heads,
            self.prepend_bos,
            self.append_eos,
            eos_idx=self.eos_idx,
        )
        self.emb_layer_norm_after = ESM1bLayerNorm(self.embed_dim)

        # The LM head shares its projection weights with the token embedding.
        self.lm_head = RobertaLMHead(
            embed_dim=self.embed_dim,
            output_dim=self.alphabet_size,
            weight=self.embed_tokens.weight,
        )
        self.supervised_linear = nn.Linear(self.embed_dim, 1)
        self.structure_linear = nn.Linear(self.embed_dim, 3)

    def forward(
        self,
        tokens,
        repr_layers=None,
        need_head_weights=True,
        return_contacts=True,
        return_representation=True,
        return_attentions_symm=False,
        return_attentions=False,
    ):
        """Run the model on a batch of token indices.

        Args:
            tokens: 2-D integer tensor of token indices (batch, length).
            repr_layers: Iterable of layer indices whose hidden states should
                be returned (0 is the embedding output). ``None`` or an empty
                iterable collects nothing.
            need_head_weights: Collect per-head attention maps from every
                layer. Forced to True when ``return_contacts`` is True.
            return_contacts: Run ``contact_head`` and add ``"contacts"``.
            return_representation: Add ``"representations"`` to the result.
            return_attentions_symm: Add ``"attentions_symm"``; only honoured
                when ``return_contacts`` is True.
            return_attentions: Add ``"attentions"``; only honoured when head
                weights are collected.

        Returns:
            A dict that always holds ``"logits"``, ``"logits_supervised"``
            and ``"logits_structure"``, plus the optional entries above.
        """
        if return_contacts:
            # The contact head consumes the attention maps.
            need_head_weights = True

        assert tokens.ndim == 2
        padding_mask = tokens.eq(self.padding_idx)  # B, T

        x = self.embed_scale * self.embed_tokens(tokens)  # B x T x C

        if self.token_dropout:
            is_masked = tokens == self.mask_idx
            x.masked_fill_(is_masked.unsqueeze(-1), 0.0)
            # Rescale by the ratio between the fixed training-time mask rate
            # and the mask rate observed in each sequence; this rescaling is
            # part of token dropout, as in the upstream ESM-2 implementation.
            mask_ratio_train = 0.15 * 0.8
            src_lengths = (~padding_mask).sum(-1)
            mask_ratio_observed = is_masked.sum(-1).to(x.dtype) / src_lengths
            x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]

        # Zero the embeddings at padded positions.
        x = x * (1 - padding_mask.unsqueeze(-1).type_as(x))

        repr_layers = set(repr_layers) if repr_layers is not None else set()
        hidden_representations = {}
        if 0 in repr_layers:
            hidden_representations[0] = x

        attn_weights = []

        # (B, T, E) => (T, B, E)
        x = x.transpose(0, 1)

        if not padding_mask.any():
            padding_mask = None

        for layer_idx, layer in enumerate(self.layers, start=1):
            x, attn = layer(
                x,
                self_attn_padding_mask=padding_mask,
                need_head_weights=need_head_weights,
            )
            if layer_idx in repr_layers:
                hidden_representations[layer_idx] = x.transpose(0, 1)
            if need_head_weights:
                # (H, B, T, T) => (B, H, T, T)
                attn_weights.append(attn.transpose(1, 0))

        x = self.emb_layer_norm_after(x)
        x = x.transpose(0, 1)  # (T, B, E) => (B, T, E)

        # The last hidden representation should have layer norm applied.
        last_layer = len(self.layers)
        if last_layer in repr_layers:
            hidden_representations[last_layer] = x

        # Sequence-level head reads the first token position only.
        x_supervised = self.supervised_linear(x[:, 0, :])
        x_structure = self.structure_linear(x)
        x = self.lm_head(x)

        result = {
            "logits": x,
            "logits_supervised": x_supervised,
            "logits_structure": x_structure,
        }
        if return_representation:
            result["representations"] = hidden_representations

        if need_head_weights:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            if padding_mask is not None:
                # Zero every attention entry whose row or column is padding.
                attention_mask = 1 - padding_mask.type_as(attentions)
                attention_mask = attention_mask.unsqueeze(1) * attention_mask.unsqueeze(2)
                attentions = attentions * attention_mask[:, None, None, :, :]
            if return_attentions:
                result["attentions"] = attentions
            if return_contacts:
                attentions_symm, contacts = self.contact_head(tokens, attentions)
                result["contacts"] = contacts
                if return_attentions_symm:
                    result["attentions_symm"] = attentions_symm

        return result

    def predict_contacts(self, tokens):
        """Return the ``"contacts"`` output for ``tokens``."""
        return self(tokens, return_contacts=True)["contacts"]

    def predict_symmetric_attentions(self, tokens):
        """Return the ``"attentions_symm"`` output for ``tokens``."""
        # The key is only populated when explicitly requested.
        return self(tokens, return_contacts=True, return_attentions_symm=True)["attentions_symm"]

    def predict_attentions(self, tokens):
        """Return the stacked per-layer, per-head attention maps."""
        # The key is only populated when explicitly requested.
        return self(tokens, need_head_weights=True, return_attentions=True)["attentions"]

    def predict_representations(self, tokens):
        """Return the ``"representations"`` dict for ``tokens``."""
        return self(tokens, return_representation=True)["representations"]

    def predict_logits(self, tokens):
        """Return the language-model logits for ``tokens``."""
        return self(tokens)["logits"]

    def predict_logits_supervised(self, tokens):
        """Return the sequence-level supervised logits for ``tokens``."""
        return self(tokens)["logits_supervised"]

    def predict_logits_structure(self, tokens):
        """Return the per-token structure logits for ``tokens``."""
        return self(tokens)["logits_structure"]