🚀 SaProt
SaProt is a protein language model that can be used for a wide range of protein-related tasks. It supports multiple usage methods and achieves strong performance on most tasks when given structural input.
🚀 Quick Start
We provide two ways to use SaProt: through the Hugging Face model classes, or in the same way as in the esm GitHub repository. Users can choose whichever they prefer.
⚠️ Important Note
SaProt requires structure-aware (SA token) input for optimal performance. Amino-acid-only input also works, but the model must be fine-tuned for it; frozen embeddings are reliable only for SA sequences, not plain AA sequences. With structural input, SaProt surpasses ESM2 on most tasks.
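A structure-aware (SA) sequence interleaves each uppercase amino acid with a lowercase Foldseek 3Di structure token, with # standing in for positions whose structure is unknown or masked. The snippet below is a minimal sketch of how such a sequence could be assembled; the helper combine_sa_sequence is our own illustration and is not part of the SaProt codebase.

# Hypothetical helper (not part of the SaProt repo): build an SA sequence from
# an amino acid sequence and its Foldseek 3Di string ("#" = structure unknown/masked).
def combine_sa_sequence(aa_seq: str, foldseek_seq: str) -> str:
    assert len(aa_seq) == len(foldseek_seq)
    # Each SA token is one uppercase amino acid followed by one lowercase 3Di letter
    return "".join(aa + s.lower() for aa, s in zip(aa_seq, foldseek_seq))

print(combine_sa_sequence("MEVQLVQYK", "#vpp#ydav"))
# -> "M#EvVpQpL#VyQdYaKv"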
✨ Features
💻 Usage Examples
Basic Usage
We will show how to use SaProt through Hugging Face and through the ESM interface respectively, and also how to predict mutational effects and extract protein embeddings.
Advanced Usage
The following sections give the specific code examples.
Hugging Face model
The following code shows how to load the model.
from transformers import EsmTokenizer, EsmForMaskedLM

# Load the tokenizer and model from the local SaProt checkpoint
model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)

device = "cuda"
model.to(device)

# Structure-aware sequence: amino acids interleaved with lowercase Foldseek tokens
seq = "M#EvVpQpL#VyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

outputs = model(**inputs)
print(outputs.logits.shape)
"""
['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
torch.Size([1, 11, 446])
"""
ESM model
The ESM version of the weights is stored in the same folder, named SaProt_650M_AF2.pt. We provide a function to load the model.
from utils.esm_loader import load_esm_saprot

# Load the ESM-style checkpoint and its accompanying alphabet
model_path = "/your/path/to/SaProt_650M_AF2.pt"
model, alphabet = load_esm_saprot(model_path)
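Assuming the returned model and alphabet follow the standard fair-esm interface (batch converter plus repr_layers), they can be used to encode SA sequences and extract per-residue representations. The following is a hedged sketch rather than code from the SaProt repository.

import torch

# Hedged sketch assuming the fair-esm interface; adapt if load_esm_saprot
# exposes a different API.
batch_converter = alphabet.get_batch_converter()
data = [("protein1", "M#EvVpQpL#VyQdYaKv")]
labels, strs, tokens = batch_converter(data)

model.eval()
with torch.no_grad():
    results = model(tokens, repr_layers=[33])  # layer 33 = last layer of the 650M model

# Per-residue representations from the last layer (includes special tokens)
representations = results["representations"][33]
print(representations.shape)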
Predict mutational effect
We provide a function to predict the mutational effect of a protein sequence. The example below shows how to predict the mutational effect at a specific position. If you are using an AlphaFold2-predicted structure, we strongly recommend applying the pLDDT mask (see below).
from model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel

config = {
    "foldseek_path": None,
    "config_path": "/your/path/to/SaProt_650M_AF2",
    "load_pretrained": True,
}
model = SaprotFoldseekMutationModel(**config)
tokenizer = model.tokenizer

device = "cuda"
model.eval()
model.to(device)

seq = "M#EvVpQpL#VyQdYaKv"

# Predict the effect of a single-point mutation (V at position 3 mutated to A)
mut_info = "V3A"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

# Predict the effect of multiple mutations, separated by ":"
mut_info = "V3A:Q4M"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)

# Predict the effects of all possible amino acid substitutions at position 3
mut_pos = 3
mut_dict = model.predict_pos_mut(seq, mut_pos)
print(mut_dict)

# Predict the probability of each amino acid at position 3
mut_pos = 3
mut_dict = model.predict_pos_prob(seq, mut_pos)
print(mut_dict)
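Building on predict_mut above, candidate single-point mutations can be scored and ranked in a loop. The candidate list below is made up purely for illustration, and we assume here that higher scores indicate mutations the model considers more favorable.

# Usage sketch: score and rank a few hypothetical candidate mutations of seq
candidates = ["V3A", "V3L", "Q4M", "Y8F"]
scores = {mut: model.predict_mut(seq, mut) for mut in candidates}

# Sort so that the mutations with the highest predicted scores come first
for mut, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{mut}: {score:.4f}")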
Get protein embeddings
If you want to generate protein embeddings, you can refer to the following code. The embeddings are the average of the hidden states of the last layer.
from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer

config = {
    "task": "base",
    "config_path": "/your/path/to/SaProt_650M_AF2",
    "load_pretrained": True,
}
model = SaprotBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])

device = "cuda"
model.to(device)

seq = "M#EvVpQpL#VyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)

inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# Mean-pool the last-layer hidden states to get one embedding per sequence
embeddings = model.get_hidden_states(inputs, reduction="mean")
print(embeddings[0].shape)
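As a follow-up usage sketch (not from the SaProt repository), the mean-pooled embeddings can be compared with cosine similarity, for example to gauge how similar two proteins look to the model.

import torch
import torch.nn.functional as F

# Hedged sketch: embed two SA sequences and compare them with cosine similarity
def embed(sa_seq):
    inputs = tokenizer(sa_seq, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        return model.get_hidden_states(inputs, reduction="mean")[0]

emb_a = embed("M#EvVpQpL#VyQdYaKv")
emb_b = embed("M#EvVpQpL#VyQdYaRv")  # the same toy sequence with a K9R substitution

similarity = F.cosine_similarity(emb_a.unsqueeze(0), emb_b.unsqueeze(0)).item()
print(similarity)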
📄 License
This project is licensed under the MIT license.