模型简介
模型特点
模型能力
使用案例
🚀 ESM++
ESM++ 是 ESMC(许可证)的忠实实现,它支持批量处理,并且无需 ESM Python 包即可与标准的 Huggingface 兼容。大版本对应 ESMC 的 6 亿参数版本。
🚀 快速开始
注意事项
之前 Huggingface 的权重绑定存在一个 bug,导致 ESM++ 的对数几率与 ESMC 不同。该 bug 现已修复。
✨ 主要特性
- 忠实实现 ESMC,支持批量处理和 Huggingface 兼容。
- 支持序列和标记级别的分类任务。
- 支持以不同浮点精度加载权重。
- 支持返回注意力图。
- 可进行微调。
- 提供模型探针评估。
- 具有较高的推理速度。
📦 安装指南
文档未提及安装步骤,故跳过此章节。
💻 使用示例
基础用法
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training
output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
高级用法
支持序列和标记级别的分类任务
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
以不同浮点精度加载权重
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True, torch_dtype=torch.float16) # or torch.bfloat16
嵌入整个数据集
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
],
tokenizer=model.tokenizer,
batch_size=2, # adjust for your GPU memory
max_len=512, # adjust for your needs
full_embeddings=False, # if True, no pooling is performed
embed_dtype=torch.float32, # cast to what dtype you want
pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
sql=False, # if True, embeddings will be stored in SQLite database
sql_db_path='embeddings.db',
save=True, # if True, embeddings will be saved as a .pth file
save_path='embeddings.pth',
)
# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
model.embed_dataset()
Args:
sequences: List of protein sequences
batch_size: Batch size for processing
max_len: Maximum sequence length
full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
pooling_type: Type of pooling ('mean' or 'cls')
num_workers: Number of workers for data loading, 0 for the main process
sql: Whether to store embeddings in SQLite database - will be stored in float32
sql_db_path: Path to SQLite database
Returns:
Dictionary mapping sequences to embeddings, or None if sql=True
Note:
- If sql=True, embeddings can only be stored in float32
- sql is ideal if you need to stream a very large dataset for training in real-time
- save=True is ideal if you can store the entire embedding dictionary in RAM
- sql will be used if it is True and save is True or False
- If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
- Sequences will be truncated to max_len and sorted by length in descending order for faster processing
使用 🤗 peft 进行微调
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
# these modules handle ESM++ and ESM2 attention layers
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8, # choose lora parameters to your liking
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
# Apply LoRA to the model
model = get_peft_model(model, lora_config)
# Unfreeze the classifier head
for param in model.classifier.parameters():
param.requires_grad = True
返回注意力图
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att) # 33, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each
📚 详细文档
从 ESM 包加载权重
你可以通过将 .from_pretrained(...)
替换为 .from_pretrained_esm('esmc_600m')
来从 ESM 包而不是 transformers 加载权重。
模型探针
我们采用线性探测技术对各种蛋白质语言模型(PLMs)和标准数据集进行评估,类似于我们之前的 论文,以评估池化隐藏状态与有价值属性之间的内在相关性。ESMC(以及 ESM++)表现非常出色。
推理速度
我们研究了各种 ESM 模型在 H100 上的吞吐量。在 ESMC 和 ESM++ 之间添加高效的批量处理显著提高了吞吐量,尽管 ESM++ 在批量大小为 1 时也比 ESMC 更快。ESM++ 小版本在处理长序列时甚至比 ESM2 - 35M 更快!在 Linux 机器上使用 PyTorch > 2.5 时,收益最为明显。
🔧 技术细节
浮点精度和实现的比较
我们测量了 fp32 权重与 fp16 或 bf16 的最后隐藏状态的差异。我们发现 fp16 更接近 fp32 的输出,因此建议以 fp16 加载。
请注意,ESM 包也以 fp32 加载 ESMC,但默认转换为 bf16,这在推理/训练中各有优缺点 - 因此你可以根据需要选择半精度加载。
FP16 的平均均方误差(MSE):0.00000003
BF16 的平均均方误差(MSE):0.00000122
我们还测量了 ESM++ 与 ESMC(均为 bfloat16)在 1000 个随机序列上的输出差异,以确保与 ESM 包兼容。
最后隐藏状态的平均均方误差(MSE):2.46e - 09
📄 许可证
请参考 ESMC 许可证。
引用
如果你使用了此实现或相关工作,请引用它(以及 ESMC 预印本)。
@misc {ESMPlusPlus,
author = { Hallee, L. and Bichara, D. and Gleghorn, J, P. },
title = { ESMPlusPlus },
year = 2024,
url = { https://huggingface.co/Synthyra/ESMplusplus_small },
doi = { 10.57967/hf/3726 },
publisher = { Hugging Face }
}
微调示例
如需更详细的微调示例,请查看我们的示例脚本 此处。











