模型概述
模型特點
模型能力
使用案例
🚀 ESM++
ESM++是ESMC(許可證)的忠實實現,它支持批量處理,並且無需ESM Python包即可與標準的Huggingface兼容。小版本對應於ESMC的3億參數版本。該項目解決了在不依賴ESM Python包的情況下,實現與Huggingface的標準兼容以及批量處理的問題,為相關研究和開發提供了便利。
🚀 快速開始
之前Huggingface權重綁定存在一個bug,導致ESM++的對數幾率與ESMC不同。該bug現已修復。
✨ 主要特性
- 忠實實現ESMC,支持批量處理。
- 無需ESM Python包,與標準Huggingface兼容。
- 支持序列和標記級別的分類任務。
- 提供不同浮點精度的加載選項。
- 可返回注意力圖。
- 相比ESMC和其他模型,推理速度更快。
📦 安裝指南
文檔未提及安裝步驟,暫不展示。
💻 使用示例
基礎用法
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training
output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 960)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
高級用法
支持序列和標記級別的分類任務
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
以不同浮點精度加載權重
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16) # or torch.bfloat16
嵌入整個數據集
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
],
tokenizer=model.tokenizer,
batch_size=2, # adjust for your GPU memory
max_len=512, # adjust for your needs
full_embeddings=False, # if True, no pooling is performed
embed_dtype=torch.float32, # cast to what dtype you want
pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
sql=False, # if True, embeddings will be stored in SQLite database
sql_db_path='embeddings.db',
save=True, # if True, embeddings will be saved as a .pth file
save_path='embeddings.pth',
)
# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
model.embed_dataset()
Args:
sequences: List of protein sequences
batch_size: Batch size for processing
max_len: Maximum sequence length
full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
pooling_type: Type of pooling ('mean' or 'cls')
num_workers: Number of workers for data loading, 0 for the main process
sql: Whether to store embeddings in SQLite database - will be stored in float32
sql_db_path: Path to SQLite database
Returns:
Dictionary mapping sequences to embeddings, or None if sql=True
Note:
- If sql=True, embeddings can only be stored in float32
- sql is ideal if you need to stream a very large dataset for training in real-time
- save=True is ideal if you can store the entire embedding dictionary in RAM
- sql will be used if it is True and save is True or False
- If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
- Sequences will be truncated to max_len and sorted by length in descending order for faster processing
使用🤗 peft進行微調
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
# these modules handle ESM++ and ESM2 attention layers
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8, # choose lora parameters to your liking
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
# Apply LoRA to the model
model = get_peft_model(model, lora_config)
# Unfreeze the classifier head
for param in model.classifier.parameters():
param.requires_grad = True
如需更全面的微調示例,請查看我們的示例腳本此處。
返回注意力圖
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att) # 30, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each
📚 詳細文檔
不同浮點精度和實現方式的比較
我們測量了fp32權重與fp16或bf16的最後隱藏狀態的差異。發現fp16更接近fp32的輸出,因此建議以fp16加載。請注意,ESM包也以fp32加載ESMC,但默認轉換為bf16,這在推理/訓練中各有優缺點,因此可以根據需要選擇半精度加載方式。
- FP32與FP16的平均均方誤差(MSE):0.00000003
- FP32與BF16的平均均方誤差(MSE):0.00000140
我們還測量了ESM++與ESMC(均為bfloat16)在1000個隨機序列上的輸出差異,以確保與ESM包兼容。
- 最後隱藏狀態的平均均方誤差(MSE):7.74e - 10
你可以通過將.from_pretrained(...)
替換為.from_pretrained_esm('esmc_300m')
,從ESM包而不是transformers加載權重。
模型探針
我們在各種PLM和標準數據集上採用線性探測技術,類似於我們之前的論文,以評估池化隱藏狀態與有價值屬性之間的內在相關性。ESMC(以及ESM++)表現出色。
下圖展示了在負控制(隨機向量嵌入)和最佳表現者之間進行歸一化後的性能。分類任務得分在MCC和F1(或多標籤的F1max)之間平均,迴歸任務在Spearman rho和R2之間平均。
推理速度
我們研究了各種ESM模型在H100上的吞吐量。在ESMC和ESM++之間添加高效的批量處理顯著提高了吞吐量,即使在批量大小為1的情況下,ESM++也比ESMC更快。ESM++小版本在處理長序列時甚至比ESM2 - 35M更快!在Linux機器上使用PyTorch > 2.5時,提升最為明顯。
🔧 技術細節
文檔未提及技術實現細節,暫不展示。
📄 許可證
如果你使用此實現或相關工作,請引用它(以及ESMC預印本)。
@misc {ESMPlusPlus,
author = { Hallee, L. and Bichara, D. and Gleghorn, J, P. },
title = { ESMPlusPlus },
year = 2024,
url = { https://huggingface.co/Synthyra/ESMplusplus_small },
doi = { 10.57967/hf/3725 },
publisher = { Hugging Face }
}











