ESM++
ESM++ is a faithful implementation of ESMC (license). It allows for batching and standard Hugging Face compatibility without requiring the ESM Python package. The small version corresponds to the 300 million parameter version of ESMC.
Quick Start
Important Note
There was previously a bug with Hugging Face weight tying that caused the logits of ESM++ to differ from ESMC. That bug is now resolved.
Usage Examples
Basic Usage
from transformers import AutoModelForMaskedLM
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True)
tokenizer = model.tokenizer
sequences = ['MPRTEIN', 'MSEQWENCE']
tokenized = tokenizer(sequences, padding=True, return_tensors='pt')
output = model(**tokenized)
print(output.logits.shape)  # language modeling logits: (batch_size, seq_len, vocab_size)
print(output.last_hidden_state.shape)  # final hidden states: (batch_size, seq_len, hidden_size)
print(output.loss)  # None unless labels are passed (see the sketch below)
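The loss field is only populated when labels are provided. A minimal sketch, assuming the model follows the standard Hugging Face masked-LM convention (labels share the shape of input_ids and positions set to -100 are ignored):

labels = tokenized['input_ids'].clone()
labels[tokenized['attention_mask'] == 0] = -100  # assumption: padding is ignored via the usual -100 convention
output = model(**tokenized, labels=labels)
print(output.loss)  # now a scalar tensor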
Advanced Usage
ESM++ also supports sequence- and token-level classification tasks, like ESM2. Simply pass the number of labels during initialization.
from transformers import AutoModelForSequenceClassification, AutoModelForTokenClassification
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape)  # (batch_size, num_labels)
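The token-level head is used the same way; a brief sketch mirroring the example above:

model = AutoModelForTokenClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
logits = model(**tokenized).logits
print(logits.shape)  # (batch_size, seq_len, num_labels)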
Embedding Datasets
import torch
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16)
embedding_dict = model.embed_dataset(
sequences=[
'MALWMRLLPLLALLALWGPDPAAA', ...
],
tokenizer=model.tokenizer,
batch_size=2,
max_len=512,
full_embeddings=False,
embed_dtype=torch.float32,
pooling_types=['mean', 'cls'],
num_workers=0,
sql=False,
sql_db_path='embeddings.db',
save=True,
save_path='embeddings.pth',
)
model.embed_dataset()
Args:
sequences: List of protein sequences
batch_size: Batch size for processing
max_len: Maximum sequence length
full_embeddings: Whether to return full residue-wise (True) embeddings or pooled (False)
pooling_types: Types of pooling to apply (e.g. ['mean', 'cls'])
num_workers: Number of workers for data loading, 0 for the main process
sql: Whether to store embeddings in SQLite database - will be stored in float32
sql_db_path: Path to SQLite database
Returns:
Dictionary mapping sequences to embeddings, or None if sql=True
Note:
- If sql=True, embeddings can only be stored in float32
- sql is ideal if you need to stream a very large dataset for training in real-time
- save=True is ideal if you can store the entire embedding dictionary in RAM (see the sketch after these notes)
- If sql=True, the SQLite database takes precedence over save
- If your SQL database or .pth file already exists, it will be scanned first for sequences that have already been embedded
- Sequences will be truncated to max_len and sorted by length in descending order for faster processing
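To reuse the result later, the saved .pth file can be loaded back as the same sequence-to-embedding dictionary (a minimal sketch; the exact on-disk format is an assumption based on the returned dictionary):

import torch
embedding_dict = torch.load('embeddings.pth')  # assumed: dict mapping sequence strings to tensors
vector = embedding_dict['MALWMRLLPLLALLALWGPDPAAA']
print(vector.shape)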
Fine-tuning
from transformers import AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model
model = AutoModelForSequenceClassification.from_pretrained('Synthyra/ESMplusplus_small', num_labels=2, trust_remote_code=True)
target_modules = ["layernorm_qkv.1", "out_proj", "query", "key", "value", "dense"]
lora_config = LoraConfig(
r=8,
lora_alpha=16,
lora_dropout=0.01,
bias="none",
target_modules=target_modules,
)
model = get_peft_model(model, lora_config)
for param in model.classifier.parameters():
param.requires_grad = True
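As a quick sanity check (a sketch relying on the standard PEFT API), you can confirm that only the LoRA adapters and the classifier head remain trainable:

model.print_trainable_parameters()  # PEFT utility: trainable vs. total parameter counts
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable[:5])  # should list LoRA and classifier parameters only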
For a more thorough example of fine-tuning, check out our example script here.
Returning Attention Maps
output = model(**tokenized, output_attentions=True)
att = output.attentions  # tuple of attention maps, one per layer
len(att)  # number of layers
Important Note
F.scaled_dot_product_attention is normally used for the attention calculations, which is much faster than a manual PyTorch implementation but cannot return attention maps. Passing output_attentions=True makes ESM++ compute attention manually instead. This is much slower, so only use it when you actually need the attention maps.
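For downstream use, the per-layer maps can be stacked and aggregated; a minimal sketch, assuming each element of output.attentions has the usual Hugging Face shape (batch_size, num_heads, seq_len, seq_len):

import torch
att_stack = torch.stack(output.attentions)  # (num_layers, batch_size, num_heads, seq_len, seq_len)
mean_att = att_stack[:, 0].mean(dim=1)  # head-averaged maps for the first sequence: (num_layers, seq_len, seq_len)
print(mean_att.shape)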
Comparison across Floating-point Precision and Implementations
We measured the mean squared error (MSE) between the last hidden states produced with the fp32 weights and with fp16 or bf16 weights. fp16 is closer to the fp32 outputs, so we recommend loading in fp16.
| Comparison | Average MSE |
|---|---|
| FP32 vs. FP16 | 0.00000003 |
| FP32 vs. BF16 | 0.00000140 |
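These numbers can be reproduced with a short script along these lines (a sketch only; assumes a CUDA device and reuses the tokenized batch from the basic example):

import torch
device = 'cuda'
model_fp32 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True).to(device).eval()
model_fp16 = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16).to(device).eval()
batch = {k: v.to(device) for k, v in tokenized.items()}
with torch.no_grad():
    h32 = model_fp32(**batch).last_hidden_state.float()
    h16 = model_fp16(**batch).last_hidden_state.float()
print(torch.mean((h32 - h16) ** 2).item())  # average MSE over all positions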
We also measured the difference between the outputs of ESM++ and ESMC (both in bfloat16) on 1000 random sequences to ensure parity with the ESM package. The average MSE of the last hidden state is 7.74e-10.
You can load the weights from the ESM package instead of transformers by replacing .from_pretrained(...) with .from_pretrained_esm('esmc_300m').
Model Probes
We employ linear probing techniques on various PLMs and standard datasets, similar to our previous paper, to assess the intrinsic correlation between pooled hidden states and valuable properties. ESMC (and thus ESM++) performs very well.
The plot below shows performance normalized between the negative control (random vector embeddings) and the best performer. Classification task scores are averaged between MCC and F1 (or F1max for multilabel), and regression task scores are averaged between Spearman rho and R2.

Inference Speeds
We measured the throughput of various ESM models on an H100. The efficient batching added in ESM++ significantly improves throughput over ESMC, and ESM++ is also faster than ESMC at batch size one. ESM++ small is even faster than ESM2-35M with long sequences! The largest gains are seen with PyTorch > 2.5 on Linux machines.
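A rough way to check throughput on your own hardware (a sketch with illustrative batch size and sequence length; assumes a CUDA device):

import time
import torch
device = 'cuda'
model = AutoModelForMaskedLM.from_pretrained('Synthyra/ESMplusplus_small', trust_remote_code=True, torch_dtype=torch.float16).to(device).eval()
seqs = ['M' * 512] * 8  # hypothetical batch: 8 sequences of length 512
batch = model.tokenizer(seqs, padding=True, return_tensors='pt').to(device)
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
    for _ in range(10):
        model(**batch)
torch.cuda.synchronize()
print(f'{10 * len(seqs) / (time.time() - start):.1f} sequences / second')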

Citation
If you use any of this implementation or work, please cite it (as well as the ESMC preprint).
@misc{ESMPlusPlus,
  author = {Hallee, L. and Bichara, D. and Gleghorn, J. P.},
  title = {ESMPlusPlus},
  year = {2024},
  url = {https://huggingface.co/Synthyra/ESMplusplus_small},
  doi = {10.57967/hf/3725},
  publisher = {Hugging Face}
}