# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import re
import urllib.error
import warnings
from argparse import Namespace
from pathlib import Path

import torch

import esm
from esm.model.esm2 import ESM2


def _has_regression_weights(model_name):
    """Return whether we expect / require regression weights;
    currently that is all models except ESM-1v and ESM-IF."""
    return not ("esm1v" in model_name or "esm_if" in model_name)


def load_model_and_alphabet(model_name):
    if model_name.endswith(".pt"):  # treat as filepath
        return load_model_and_alphabet_local(model_name)
    else:
        return load_model_and_alphabet_hub(model_name)


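# Usage sketch (not part of the upstream module): `load_model_and_alphabet` accepts
# either a hub model name (one of the entry points defined at the bottom of this
# file) or a path to a local ".pt" checkpoint, for example:
#
#     model, alphabet = load_model_and_alphabet("esm2_t33_650M_UR50D")
#     batch_converter = alphabet.get_batch_converter()

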
def load_hub_workaround(url):
    try:
        data = torch.hub.load_state_dict_from_url(url, progress=False, map_location="cpu")
    except RuntimeError:
        # Pytorch version issue - see https://github.com/pytorch/pytorch/issues/43106
        fn = Path(url).name
        data = torch.load(
            f"{torch.hub.get_dir()}/checkpoints/{fn}",
            map_location="cpu",
        )
    except urllib.error.HTTPError as e:
        raise Exception(f"Could not load {url}; did you specify a correct model name?") from e
    return data


def load_regression_hub(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/regression/{model_name}-contact-regression.pt"
    regression_data = load_hub_workaround(url)
    return regression_data


def _download_model_and_regression_data(model_name):
    url = f"https://dl.fbaipublicfiles.com/fair-esm/models/{model_name}.pt"
    model_data = load_hub_workaround(url)
    if _has_regression_weights(model_name):
        regression_data = load_regression_hub(model_name)
    else:
        regression_data = None
    return model_data, regression_data


def load_model_and_alphabet_hub(model_name):
    model_data, regression_data = _download_model_and_regression_data(model_name)
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def load_model_and_alphabet_local(model_location):
    """Load from local path. The regression weights need to be co-located."""
    model_location = Path(model_location)
    model_data = torch.load(str(model_location), map_location="cpu")
    model_name = model_location.stem
    if _has_regression_weights(model_name):
        regression_location = str(model_location.with_suffix("")) + "-contact-regression.pt"
        regression_data = torch.load(regression_location, map_location="cpu")
    else:
        regression_data = None
    return load_model_and_alphabet_core(model_name, model_data, regression_data)


def has_emb_layer_norm_before(model_state):
    """Determine whether layer norm needs to be applied before the encoder"""
    return any(k.startswith("emb_layer_norm_before") for k in model_state)


def _load_model_and_alphabet_core_v1(model_data):
    import esm  # since esm.inverse_folding is imported below, you actually have to re-import esm here

    alphabet = esm.Alphabet.from_architecture(model_data["args"].arch)

    if model_data["args"].arch == "roberta_large":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(arg[0])): arg[1] for arg in model_data["model"].items()}
        model_state["embed_tokens.weight"][alphabet.mask_idx].zero_()  # For token drop
        model_args["emb_layer_norm_before"] = has_emb_layer_norm_before(model_state)
        model_type = esm.ProteinBertModel
    elif model_data["args"].arch == "protein_bert_base":
        # upgrade state dict
        pra = lambda s: "".join(s.split("decoder_")[1:] if "decoder" in s else s)
        prs = lambda s: "".join(s.split("decoder.")[1:] if "decoder" in s else s)
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs(arg[0]): arg[1] for arg in model_data["model"].items()}
        model_type = esm.ProteinBertModel
    elif model_data["args"].arch == "msa_transformer":
        # upgrade state dict
        pra = lambda s: "".join(s.split("encoder_")[1:] if "encoder" in s else s)
        prs1 = lambda s: "".join(s.split("encoder.")[1:] if "encoder" in s else s)
        prs2 = lambda s: "".join(
            s.split("sentence_encoder.")[1:] if "sentence_encoder" in s else s
        )
        prs3 = lambda s: s.replace("row", "column") if "row" in s else s.replace("column", "row")
        model_args = {pra(arg[0]): arg[1] for arg in vars(model_data["args"]).items()}
        model_state = {prs1(prs2(prs3(arg[0]))): arg[1] for arg in model_data["model"].items()}
        if model_args.get("embed_positions_msa", False):
            emb_dim = model_state["msa_position_embedding"].size(-1)
            model_args["embed_positions_msa_dim"] = emb_dim  # initial release, bug: emb_dim==1

        model_type = esm.MSATransformer

    elif "invariant_gvp" in model_data["args"].arch:
        import esm.inverse_folding

        model_type = esm.inverse_folding.gvp_transformer.GVPTransformerModel
        model_args = vars(model_data["args"])  # convert Namespace -> dict

        def update_name(s):
            # Map the module names in checkpoints trained with internal code to
            # the updated module names in open source code
            s = s.replace("W_v", "embed_graph.embed_node")
            s = s.replace("W_e", "embed_graph.embed_edge")
            s = s.replace("embed_scores.0", "embed_confidence")
            s = s.replace("embed_score.", "embed_graph.embed_confidence.")
            s = s.replace("seq_logits_projection.", "")
            s = s.replace("embed_ingraham_features", "embed_dihedrals")
            s = s.replace("embed_gvp_in_local_frame.0", "embed_gvp_output")
            s = s.replace("embed_features_in_local_frame.0", "embed_gvp_input_features")
            return s

        model_state = {
            update_name(sname): svalue
            for sname, svalue in model_data["model"].items()
            if "version" not in sname
        }

    else:
        raise ValueError("Unknown architecture selected")

    model = model_type(
        Namespace(**model_args),
        alphabet,
    )

    return model, alphabet, model_state


def _load_model_and_alphabet_core_v2(model_data):
    def upgrade_state_dict(state_dict):
        """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
        prefixes = ["encoder.sentence_encoder.", "encoder."]
        pattern = re.compile("^" + "|".join(prefixes))
        state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
        return state_dict

    cfg = model_data["cfg"]["model"]
    state_dict = model_data["model"]
    state_dict = upgrade_state_dict(state_dict)
    alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
    model = ESM2(
        num_layers=cfg.encoder_layers,
        embed_dim=cfg.encoder_embed_dim,
        attention_heads=cfg.encoder_attention_heads,
        alphabet=alphabet,
        token_dropout=cfg.token_dropout,
    )
    return model, alphabet, state_dict


def load_model_and_alphabet_core(model_name, model_data, regression_data=None):
    if regression_data is not None:
        # Merge the contact-head regression weights into the main state dict
        model_data["model"].update(regression_data["model"])

    if model_name.startswith("esm2"):
        model, alphabet, model_state = _load_model_and_alphabet_core_v2(model_data)
    else:
        model, alphabet, model_state = _load_model_and_alphabet_core_v1(model_data)

    expected_keys = set(model.state_dict().keys())
    found_keys = set(model_state.keys())

    if regression_data is None:
        # Without regression weights, only the contact head parameters may be missing
        expected_missing = {"contact_head.regression.weight", "contact_head.regression.bias"}
        error_msgs = []
        missing = (expected_keys - found_keys) - expected_missing
        if missing:
            error_msgs.append(f"Missing key(s) in state_dict: {missing}.")
        unexpected = found_keys - expected_keys
        if unexpected:
            error_msgs.append(f"Unexpected key(s) in state_dict: {unexpected}.")

        if error_msgs:
            raise RuntimeError(
                "Error(s) in loading state_dict for {}:\n\t{}".format(
                    model.__class__.__name__, "\n\t".join(error_msgs)
                )
            )
        if expected_missing - found_keys:
            warnings.warn(
                "Regression weights not found, predicting contacts will not produce correct results."
            )

    model.load_state_dict(model_state, strict=regression_data is not None)

    return model, alphabet


def esm1_t34_670M_UR50S():
    """34 layer transformer model with 670M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR50S")


def esm1_t34_670M_UR50D():
    """34 layer transformer model with 670M params, trained on Uniref50 Dense.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR50D")


def esm1_t34_670M_UR100():
    """34 layer transformer model with 670M params, trained on Uniref100.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t34_670M_UR100")


def esm1_t12_85M_UR50S():
    """12 layer transformer model with 85M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t12_85M_UR50S")


def esm1_t6_43M_UR50S():
    """6 layer transformer model with 43M params, trained on Uniref50 Sparse.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1_t6_43M_UR50S")


def esm1b_t33_650M_UR50S():
    """33 layer transformer model with 650M params, trained on Uniref50 Sparse.
    This is our best performing model, which will be described in a future publication.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1b_t33_650M_UR50S")


def esm_msa1_t12_100M_UR50S():
    warnings.warn(
        "This model had a minor bug in the positional embeddings, "
        "please use ESM-MSA-1b: esm.pretrained.esm_msa1b_t12_100M_UR50S()",
    )
    return load_model_and_alphabet_hub("esm_msa1_t12_100M_UR50S")


def esm_msa1b_t12_100M_UR50S():
    return load_model_and_alphabet_hub("esm_msa1b_t12_100M_UR50S")


def esm1v_t33_650M_UR90S():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 1 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")


def esm1v_t33_650M_UR90S_1():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 1 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_1")


def esm1v_t33_650M_UR90S_2():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 2 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_2")


def esm1v_t33_650M_UR90S_3():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 3 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_3")


def esm1v_t33_650M_UR90S_4():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 4 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_4")


def esm1v_t33_650M_UR90S_5():
    """33 layer transformer model with 650M params, trained on Uniref90.
    This is model 5 of a 5 model ensemble.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm1v_t33_650M_UR90S_5")


def esm_if1_gvp4_t16_142M_UR50():
    """Inverse folding model with 142M params, with 4 GVP-GNN layers, 8
    Transformer encoder layers, and 8 Transformer decoder layers, trained on
    CATH structures and 12 million AlphaFold2 predicted structures from UniRef50
    sequences.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm_if1_gvp4_t16_142M_UR50")


def esm2_t6_8M_UR50D():
    """6 layer ESM-2 model with 8M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t6_8M_UR50D")


def esm2_t12_35M_UR50D():
    """12 layer ESM-2 model with 35M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t12_35M_UR50D")


def esm2_t30_150M_UR50D():
    """30 layer ESM-2 model with 150M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t30_150M_UR50D")


def esm2_t33_650M_UR50D():
    """33 layer ESM-2 model with 650M params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t33_650M_UR50D")


def esm2_t36_3B_UR50D():
    """36 layer ESM-2 model with 3B params, trained on UniRef50.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t36_3B_UR50D")


def esm2_t48_15B_UR50D():
    """48 layer ESM-2 model with 15B params, trained on UniRef50.
    If you run out of memory while loading this model, please refer to the README
    for how to employ FSDP and ZeRO CPU offloading.

    Returns a tuple of (Model, Alphabet).
    """
    return load_model_and_alphabet_hub("esm2_t48_15B_UR50D")
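

# Usage sketch (not part of the upstream module): a minimal end-to-end example of
# loading one of the smaller ESM-2 checkpoints defined above and extracting
# per-residue representations and contact predictions. The example sequence is
# arbitrary, and repr_layers=[6] matches the 6-layer model loaded here.
if __name__ == "__main__":
    model, alphabet = esm2_t6_8M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model.eval()  # disable dropout for deterministic results

    data = [("protein1", "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ")]
    labels, strs, tokens = batch_converter(data)

    with torch.no_grad():
        out = model(tokens, repr_layers=[6], return_contacts=True)

    per_residue = out["representations"][6]  # (batch, tokens incl. BOS/EOS, embed_dim)
    contacts = out["contacts"]  # (batch, seq_len, seq_len) predicted contact probabilities
    print(per_residue.shape, contacts.shape)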