Protst Esm1b
ProtST 框架通过生物医学文本增强蛋白质序列的预训练和理解,构建了 ProtDescribe 数据集,设计了三种预训练任务,支持监督学习和零样本预测。
下载量 173
发布时间 : 1/2/2024
模型简介
ProtST 是一个结合生物医学文本和蛋白质序列的预训练框架,旨在增强蛋白质语言模型的功能理解和表示能力。
模型特点
多模态预训练
结合蛋白质序列和生物医学文本进行预训练,增强模型对蛋白质功能的理解。
零样本预测
支持零样本分类任务,无需额外训练即可进行蛋白质功能预测。
高性能表示学习
在多种表示学习基准上优于现有模型,尤其在功能预测任务中表现突出。
模型能力
蛋白质序列表示学习
蛋白质功能预测
零样本分类
多模态对齐
使用案例
生物医学研究
蛋白质亚细胞定位预测
预测蛋白质在细胞中的定位,如细胞核、线粒体等。
在零样本任务中表现出优越性能。
蛋白质功能注释
自动为蛋白质序列添加功能描述。
通过多模态对齐提高注释准确性。
🚀 ProtST:基于生物医学文本增强蛋白质序列预训练与理解
ProtST 框架旨在通过生物医学文本增强蛋白质序列的预训练和理解。现有蛋白质语言模型主要基于序列学习蛋白质表示,虽能捕捉共进化信息,但难以明确获取蛋白质功能。ProtST 构建了 ProtDescribe 数据集,用蛋白质功能和其他重要属性的文本描述扩充蛋白质序列,设计了三种预训练任务,在下游任务中支持监督学习和零样本预测,在多种表示学习基准上展现出优越性。
🚀 快速开始
ProtST 框架利用生物医学文本增强蛋白质序列的预训练和理解。在预训练阶段,设计了三种类型的任务,包括单模态掩码预测、多模态表示对齐和多模态掩码预测,以不同粒度的蛋白质属性信息增强蛋白质语言模型(PLM),同时保留其原始表示能力。在下游任务中,ProtST 支持监督学习和零样本预测。
✨ 主要特性
- 数据集构建:构建了 ProtDescribe 数据集,用蛋白质功能和其他重要属性的文本描述扩充蛋白质序列。
- 预训练任务设计:设计了三种预训练任务,增强 PLM 对蛋白质属性信息的学习。
- 多任务支持:在下游任务中支持监督学习和零样本预测。
- 性能优越:在多种表示学习基准上,ProtST 诱导的 PLM 优于先前的模型。
💻 使用示例
基础用法
以下脚本展示了如何在零样本分类任务中使用 Gaudi 运行 ProtST:
import logging
import functools
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig
import habana_frameworks.torch
logger = logging.getLogger(__name__)
def tokenize_protein(example, protein_tokenizer=None, padding=None):
protein_seqs = example["prot_seq"]
protein_inputs = protein_tokenizer(protein_seqs, padding="max_length", truncation=True, add_special_tokens=True, max_length=1024)
example["protein_input_ids"] = protein_inputs.input_ids
example["protein_attention_mask"] = protein_inputs.attention_mask
return example
def label_embedding(labels, text_tokenizer, text_model, device):
# embed label descriptions
label_feature = []
with torch.inference_mode():
for label in labels:
label_input_ids = text_tokenizer.encode(label, max_length=128,
truncation=True, add_special_tokens=False, padding="max_length")
label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
attention_mask = label_input_ids != text_tokenizer.pad_token_id
attention_mask = attention_mask.to(device)
text_outputs = text_model(label_input_ids, attention_mask=attention_mask)
label_feature.append(text_outputs["text_feature"].clone())
label_feature = torch.cat(label_feature, dim=0)
label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)
return label_feature
def zero_shot_eval(logger, device,
test_dataset, target_field, protein_model, logit_scale, label_feature):
# get prediction and target
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
preds, targets = [], []
with torch.inference_mode():
for data in tqdm(test_dataloader):
target = data[target_field]
targets.append(target)
protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)
protein_feature = protein_outputs["protein_feature"]
protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
pred = logit_scale * protein_feature @ label_feature.t()
preds.append(pred)
preds = torch.cat(preds, dim=0)
targets = torch.tensor(targets, dtype=torch.long, device=device)
accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
logger.warning("Zero-shot accuracy: %.6f" % accuracy)
if __name__ == "__main__":
# get datasets
raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
device = torch.device("hpu")
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
text_model = protst_model.text_model
logit_scale = protst_model.logit_scale
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
protein_model = wrap_in_hpu_graph(protein_model)
text_model = wrap_in_hpu_graph(text_model)
logit_scale.requires_grad = False
logit_scale = logit_scale.to(device)
logit_scale = logit_scale.exp()
protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
text_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
func_tokenize_protein = functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding=False)
test_dataset = raw_datasets.map(
func_tokenize_protein, batched=False,
remove_columns=["prot_seq"],
desc="Running tokenize_proteins on dataset",
)
labels = load_dataset("mila-intel/subloc_template", cache_dir="~/.cache/huggingface/datasets")["train"]["name"]
text_tokenizer.encode(labels[0], max_length=128, truncation=True, add_special_tokens=False)
label_feature = label_embedding(labels, text_tokenizer, text_model, device)
zero_shot_eval(logger, device, test_dataset, "localization",
protein_model, logit_scale, label_feature)
高级用法
在 CPU 上使用 optimum-intel 优化运行 ProtST:
...
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
import intel_extension_for_pytorch as ipex
from optimum.intel.generation.modeling import jit_trace
protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
protein_model = jit_trace(protein_model, "sequence-classification")
...
📚 详细文档
性能表现
零样本 ProtST - ESM - 1b 优于少样本分类器。
模型架构
代码仓库
源代码和模型权重可在 https://github.com/DeepGraphLearning/ProtST 获取。
Esm2 T36 3B UR50D
MIT
ESM-2是基于掩码语言建模目标训练的新一代蛋白质模型,适用于各类以蛋白质序列为输入的下游任务微调。
蛋白质模型
Transformers

E
facebook
3.5M
22
Esm2 T6 8M UR50D
MIT
ESM-2是基于掩码语言建模目标训练的新一代蛋白质模型,适用于对蛋白质序列进行各类任务的微调。
蛋白质模型
Transformers

E
facebook
1.5M
21
Esm2 T33 650M UR50D
MIT
ESM-2是基于掩码语言建模目标训练的最先进蛋白质模型,适用于对蛋白质序列进行分析和预测任务
蛋白质模型
Transformers

E
facebook
640.23k
41
Esm2 T12 35M UR50D
MIT
ESM-2是基于掩码语言建模目标训练的前沿蛋白质模型,适用于各类蛋白质序列分析任务
蛋白质模型
Transformers

E
facebook
332.83k
15
Prot Bert
基于BERT架构的蛋白质序列预训练模型,通过自监督学习捕捉蛋白质序列的生物物理特性
蛋白质模型
Transformers

P
Rostlab
276.10k
111
Prostt5
MIT
ProstT5是一种蛋白质语言模型,能够在蛋白质序列与结构之间进行翻译。
蛋白质模型
Transformers

P
Rostlab
252.91k
23
Prot T5 Xl Uniref50
基于T5-3B架构的蛋白质序列预训练模型,通过自监督学习捕捉蛋白质的生物物理特性
蛋白质模型
Transformers

P
Rostlab
78.45k
44
Esm2 T30 150M UR50D
MIT
ESM-2是基于遮蔽语言建模目标训练的最先进蛋白质模型,适用于对各类以蛋白质序列为输入的任务进行微调。
蛋白质模型
Transformers

E
facebook
69.91k
7
Prot Bert Bfd
基于Bert架构的蛋白质序列预训练模型,通过自监督学习从21亿蛋白质序列中提取生物物理特征
蛋白质模型
Transformers

P
Rostlab
30.60k
16
Esm1b T33 650M UR50S
MIT
ESM-1b是基于Transformer的蛋白质语言模型,通过无监督学习蛋白质序列数据,可用于蛋白质结构和功能预测。
蛋白质模型
Transformers

E
facebook
24.20k
18
精选推荐AI模型
Llama 3 Typhoon V1.5x 8b Instruct
专为泰语设计的80亿参数指令模型,性能媲美GPT-3.5-turbo,优化了应用场景、检索增强生成、受限生成和推理任务
大型语言模型
Transformers 支持多种语言

L
scb10x
3,269
16
Cadet Tiny
Openrail
Cadet-Tiny是一个基于SODA数据集训练的超小型对话模型,专为边缘设备推理设计,体积仅为Cosmo-3B模型的2%左右。
对话系统
Transformers 英语

C
ToddGoldfarb
2,691
6
Roberta Base Chinese Extractive Qa
基于RoBERTa架构的中文抽取式问答模型,适用于从给定文本中提取答案的任务。
问答系统 中文
R
uer
2,694
98