モデル概要
モデル特徴
モデル能力
使用事例
🚀 ESM++
ESM++は、ESMC(ライセンス)の忠実な実装で、ESM Pythonパッケージを必要とせずにバッチ処理と標準的なHuggingface互換性を提供します。大規模バージョンは、ESMCの6億パラメータバージョンに対応しています。
🚀 クイックスタート
以前は、Huggingfaceの重み共有に関するバグがあり、ESM++のロジットがESMCと異なる原因となっていました。このバグは現在解決されています。
✨ 主な機能
ESM++は、ESMCの忠実な実装で、バッチ処理とHuggingfaceとの互換性を提供します。また、ESM2と同様にシーケンスおよびトークンレベルの分類タスクをサポートしています。
📦 インストール
このライブラリを使用するには、transformers
ライブラリが必要です。以下のコマンドでインストールできます。
pip install transformers
💻 使用例
基本的な使用法
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_large', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
# tokenized['labels'] = tokenized['input_ids'].clone() # correctly mask input_ids and set unmasked instances of labels to -100 for MLM training
output = model(**tokenized) # get all hidden states with output_hidden_states=True
print(output.logits.shape) # language modeling logits, (batch_size, seq_len, vocab_size), (2, 11, 64)
print(output.last_hidden_state.shape) # last hidden state of the model, (batch_size, seq_len, hidden_size), (2, 11, 1152)
print(output.loss) # language modeling loss if you passed labels
#print(output.hidden_states) # all hidden states if you passed output_hidden_states=True (in tuple)
高度な使用法
シーケンス分類タスク
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape) # (batch_size, num_labels), (2, 2)
データセットの埋め込み
import torch
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ... # list of protein sequences
],
tokenizer=model.tokenizer,
batch_size=2, # adjust for your GPU memory
max_len=512, # adjust for your needs
full_embeddings=False, # if True, no pooling is performed
embed_dtype=torch.float32, # cast to what dtype you want
pooling_types=['mean', 'cls'], # more than one pooling type will be concatenated together
num_workers=0, # if you have many cpu cores, we find that num_workers = 4 is fast for large datasets
sql=False, # if True, embeddings will be stored in SQLite database
sql_db_path='embeddings.db',
save=True, # if True, embeddings will be saved as a .pth file
save_path='embeddings.pth',
)
# embedding_dict is a dictionary mapping sequences to their embeddings as tensors for .pth or numpy arrays for sql
model.embed_dataset()
Args:
sequences: List of protein sequences
batch_size: Batch size for processing
max_len: Maximum sequence length
full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
pooling_type: Type of pooling ('mean' or 'cls')
num_workers: Number of workers for data loading, 0 for the main process
sql: Whether to store embeddings in SQLite database - will be stored in float32
sql_db_path: Path to SQLite database
Returns:
Dictionary mapping sequences to embeddings, or None if sql=True
Note:
- If sql=True, embeddings can only be stored in float32
- sql is ideal if you need to stream a very large dataset for training in real-time
- save=True is ideal if you can store the entire embedding dictionary in RAM
- sql will be used if it is True and save is True or False
- If your sql database or .pth file is already present, they will be scanned first for already embedded sequences
- Sequences will be truncated to max_len and sorted by length in descending order for faster processing
微調整
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_large', num_labels=2, trust_remote_code=True)
# these modules handle ESM++ and ESM2 attention layers
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8, # choose lora parameters to your liking
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
# Apply LoRA to the model
model = get_peft_model(model, lora_config)
# Unfreeze the classifier head
for param in model.classifier.parameters():
param.requires_grad = True
より詳細な微調整の例については、こちらのスクリプトを参照してください。
アテンションマップの取得
output = model(**tokenized, output_attentions=True)
att = output.attentions
len(att) # 33, one for each layer, size (batch_size, num_heads, seq_len, seq_len) each
🔧 技術詳細
浮動小数点精度と実装の比較
我々は、fp32の重みとfp16またはbf16の最後の隠れ層の状態の差を測定しました。fp16はfp32の出力に近いことがわかったので、fp16でロードすることをお勧めします。
FP16の平均MSE: 0.00000003 BF16の平均MSE: 0.00000122
また、1000のランダムなシーケンスに対するESM++とESMC(両方ともbfloat16)の出力の差を測定し、ESMパッケージとの互換性を確認しました。
最後の隠れ層の平均MSE: 2.46e-09
transformersではなくESMパッケージから重みをロードするには、.from_pretrained(...)
を .from_pretrained_esm('esmc_600m')
に置き換えてください。
モデルプローブ
我々は、以前の論文と同様に、様々なPLMと標準データセットに線形プロービング技術を適用し、プールされた隠れ層の状態と重要な特性との間の内在的な相関を評価しました。ESMC(したがってESM++)は非常に良好な性能を示しました。
下のプロットは、ネガティブコントロール(ランダムベクトル埋め込み)と最良のパフォーマーの間で正規化された性能を示しています。分類タスクのスコアはMCCとF1(またはマルチラベルの場合はF1max)の平均、回帰タスクはスピアマンのrhoとR2の平均です。
推論速度
我々は、H100上での様々なESMモデルとそのスループットを調べました。ESMCとESM++の間で効率的なバッチ処理を追加することで、スループットが大幅に向上します。ただし、バッチサイズが1の場合でもESM++はESMCよりも高速です。ESM++ smallは、長いシーケンスではESM2-35Mよりもさらに高速です!最大のメリットは、LinuxマシンでPyTorch > 2.5を使用した場合に見られます。
📄 ライセンス
このプロジェクトは、ESMCのライセンスに基づいています。
引用
この実装や成果を使用する場合は、以下のように引用してください(ESMCのプレプリントも同様に引用してください)。
@misc {ESMPlusPlus,
author = { Hallee, L. and Bichara, D. and Gleghorn, J, P. },
title = { ESMPlusPlus },
year = 2024,
url = { https://huggingface.co/Synthyra/ESMplusplus_small },
doi = { 10.57967/hf/3726 },
publisher = { Hugging Face }
}











