🚀 SaProt模型使用说明
SaProt是一个在蛋白质相关任务中表现出色的模型,它需要结构(SA令牌)输入以达到最佳性能。仅使用氨基酸序列模式虽然可行,但必须进行微调,因为冻结的嵌入仅适用于SA,而不适用于氨基酸序列。在有结构输入的情况下,SaProt在大多数任务中超越了ESM2。本项目提供了两种使用SaProt的方式,用户可按需选择。
🚀 快速开始
我们提供了两种使用SaProt的方式,包括通过Huggingface类调用,以及采用与 esm github 相同的方式。用户可以任选其一使用。
✨ 主要特性
- 支持通过Huggingface类调用模型。
- 支持以与esm相同的方式加载模型。
- 提供预测蛋白质序列突变效应的功能。
- 可生成蛋白质嵌入。
📦 安装指南
文档未提及具体安装步骤,故跳过此章节。
💻 使用示例
基础用法
Huggingface模型调用
以下代码展示了如何加载Huggingface模型:
from transformers import EsmTokenizer, EsmForMaskedLM
model_path = "/your/path/to/SaProt_650M_AF2"
tokenizer = EsmTokenizer.from_pretrained(model_path)
model = EsmForMaskedLM.from_pretrained(model_path)
device = "cuda"
model.to(device)
seq = "M#EvVpQpL#VyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)
inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
outputs = model(**inputs)
print(outputs.logits.shape)
"""
['M#', 'Ev', 'Vp', 'Qp', 'L#', 'Vy', 'Qd', 'Ya', 'Kv']
torch.Size([1, 11, 446])
"""
esm模型调用
esm版本的模型也存储在同一文件夹中,名为 SaProt_650M_AF2.pt
。我们提供了一个函数来加载该模型:
from utils.esm_loader import load_esm_saprot
model_path = "/your/path/to/SaProt_650M_AF2.pt"
model, alphabet = load_esm_saprot(model_path)
高级用法
预测突变效应
我们提供了一个函数来预测蛋白质序列的突变效应。以下示例展示了如何预测特定位置的突变效应。如果使用AF2结构,强烈建议添加pLDDT掩码(见下文):
from model.saprot.saprot_foldseek_mutation_model import SaprotFoldseekMutationModel
config = {
"foldseek_path": None,
"config_path": "/your/path/to/SaProt_650M_AF2",
"load_pretrained": True,
}
model = SaprotFoldseekMutationModel(**config)
tokenizer = model.tokenizer
device = "cuda"
model.eval()
model.to(device)
seq = "M#EvVpQpL#VyQdYaKv"
mut_info = "V3A"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)
mut_info = "V3A:Q4M"
mut_value = model.predict_mut(seq, mut_info)
print(mut_value)
mut_pos = 3
mut_dict = model.predict_pos_mut(seq, mut_pos)
print(mut_dict)
mut_pos = 3
mut_dict = model.predict_pos_prob(seq, mut_pos)
print(mut_dict)
获取蛋白质嵌入
如果想生成蛋白质嵌入,可以参考以下代码。嵌入是最后一层隐藏状态的平均值:
from model.saprot.base import SaprotBaseModel
from transformers import EsmTokenizer
config = {
"task": "base",
"config_path": "/your/path/to/SaProt_650M_AF2",
"load_pretrained": True,
}
model = SaprotBaseModel(**config)
tokenizer = EsmTokenizer.from_pretrained(config["config_path"])
device = "cuda"
model.to(device)
seq = "M#EvVpQpL#VyQdYaKv"
tokens = tokenizer.tokenize(seq)
print(tokens)
inputs = tokenizer(seq, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}
embeddings = model.get_hidden_states(inputs, reduction="mean")
print(embeddings[0].shape)
📚 详细文档
文档未提供详细说明,故跳过此章节。
🔧 技术细节
文档未提供技术实现细节,故跳过此章节。
📄 许可证
本项目采用MIT许可证。