Protst Esm1b
ProtST 框架通過生物醫學文本增強蛋白質序列的預訓練和理解,構建了 ProtDescribe 數據集,設計了三種預訓練任務,支持監督學習和零樣本預測。
下載量 173
發布時間 : 1/2/2024
模型概述
ProtST 是一個結合生物醫學文本和蛋白質序列的預訓練框架,旨在增強蛋白質語言模型的功能理解和表示能力。
模型特點
多模態預訓練
結合蛋白質序列和生物醫學文本進行預訓練,增強模型對蛋白質功能的理解。
零樣本預測
支持零樣本分類任務,無需額外訓練即可進行蛋白質功能預測。
高性能表示學習
在多種表示學習基準上優於現有模型,尤其在功能預測任務中表現突出。
模型能力
蛋白質序列表示學習
蛋白質功能預測
零樣本分類
多模態對齊
使用案例
生物醫學研究
蛋白質亞細胞定位預測
預測蛋白質在細胞中的定位,如細胞核、線粒體等。
在零樣本任務中表現出優越性能。
蛋白質功能註釋
自動為蛋白質序列添加功能描述。
通過多模態對齊提高註釋準確性。
🚀 ProtST:基於生物醫學文本增強蛋白質序列預訓練與理解
ProtST 框架旨在通過生物醫學文本增強蛋白質序列的預訓練和理解。現有蛋白質語言模型主要基於序列學習蛋白質表示,雖能捕捉共進化信息,但難以明確獲取蛋白質功能。ProtST 構建了 ProtDescribe 數據集,用蛋白質功能和其他重要屬性的文本描述擴充蛋白質序列,設計了三種預訓練任務,在下游任務中支持監督學習和零樣本預測,在多種表示學習基準上展現出優越性。
🚀 快速開始
ProtST 框架利用生物醫學文本增強蛋白質序列的預訓練和理解。在預訓練階段,設計了三種類型的任務,包括單模態掩碼預測、多模態表示對齊和多模態掩碼預測,以不同粒度的蛋白質屬性信息增強蛋白質語言模型(PLM),同時保留其原始表示能力。在下游任務中,ProtST 支持監督學習和零樣本預測。
✨ 主要特性
- 數據集構建:構建了 ProtDescribe 數據集,用蛋白質功能和其他重要屬性的文本描述擴充蛋白質序列。
- 預訓練任務設計:設計了三種預訓練任務,增強 PLM 對蛋白質屬性信息的學習。
- 多任務支持:在下游任務中支持監督學習和零樣本預測。
- 性能優越:在多種表示學習基準上,ProtST 誘導的 PLM 優於先前的模型。
💻 使用示例
基礎用法
以下腳本展示瞭如何在零樣本分類任務中使用 Gaudi 運行 ProtST:
import logging
import functools
from tqdm import tqdm
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, AutoConfig
import habana_frameworks.torch
logger = logging.getLogger(__name__)
def tokenize_protein(example, protein_tokenizer=None, padding=None):
protein_seqs = example["prot_seq"]
protein_inputs = protein_tokenizer(protein_seqs, padding="max_length", truncation=True, add_special_tokens=True, max_length=1024)
example["protein_input_ids"] = protein_inputs.input_ids
example["protein_attention_mask"] = protein_inputs.attention_mask
return example
def label_embedding(labels, text_tokenizer, text_model, device):
# embed label descriptions
label_feature = []
with torch.inference_mode():
for label in labels:
label_input_ids = text_tokenizer.encode(label, max_length=128,
truncation=True, add_special_tokens=False, padding="max_length")
label_input_ids = [text_tokenizer.cls_token_id] + label_input_ids
label_input_ids = torch.tensor(label_input_ids, dtype=torch.long, device=device).unsqueeze(0)
attention_mask = label_input_ids != text_tokenizer.pad_token_id
attention_mask = attention_mask.to(device)
text_outputs = text_model(label_input_ids, attention_mask=attention_mask)
label_feature.append(text_outputs["text_feature"].clone())
label_feature = torch.cat(label_feature, dim=0)
label_feature = label_feature / label_feature.norm(dim=-1, keepdim=True)
return label_feature
def zero_shot_eval(logger, device,
test_dataset, target_field, protein_model, logit_scale, label_feature):
# get prediction and target
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)
preds, targets = [], []
with torch.inference_mode():
for data in tqdm(test_dataloader):
target = data[target_field]
targets.append(target)
protein_input_ids = torch.tensor(data["protein_input_ids"], dtype=torch.long, device=device).unsqueeze(0)
attention_mask = torch.tensor(data["protein_attention_mask"], dtype=torch.long, device=device).unsqueeze(0)
protein_outputs = protein_model(protein_input_ids, attention_mask=attention_mask)
protein_feature = protein_outputs["protein_feature"]
protein_feature = protein_feature / protein_feature.norm(dim=-1, keepdim=True)
pred = logit_scale * protein_feature @ label_feature.t()
preds.append(pred)
preds = torch.cat(preds, dim=0)
targets = torch.tensor(targets, dtype=torch.long, device=device)
accuracy = (preds.argmax(dim=-1) == targets).float().mean().item()
logger.warning("Zero-shot accuracy: %.6f" % accuracy)
if __name__ == "__main__":
# get datasets
raw_datasets = load_dataset("mila-intel/ProtST-SubcellularLocalization", cache_dir="~/.cache/huggingface/datasets", split='test') # cache_dir defaults to "~/.cache/huggingface/datasets"
device = torch.device("hpu")
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
text_model = protst_model.text_model
logit_scale = protst_model.logit_scale
from habana_frameworks.torch.hpu import wrap_in_hpu_graph
protein_model = wrap_in_hpu_graph(protein_model)
text_model = wrap_in_hpu_graph(text_model)
logit_scale.requires_grad = False
logit_scale = logit_scale.to(device)
logit_scale = logit_scale.exp()
protein_tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
text_tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
func_tokenize_protein = functools.partial(tokenize_protein, protein_tokenizer=protein_tokenizer, padding=False)
test_dataset = raw_datasets.map(
func_tokenize_protein, batched=False,
remove_columns=["prot_seq"],
desc="Running tokenize_proteins on dataset",
)
labels = load_dataset("mila-intel/subloc_template", cache_dir="~/.cache/huggingface/datasets")["train"]["name"]
text_tokenizer.encode(labels[0], max_length=128, truncation=True, add_special_tokens=False)
label_feature = label_embedding(labels, text_tokenizer, text_model, device)
zero_shot_eval(logger, device, test_dataset, "localization",
protein_model, logit_scale, label_feature)
高級用法
在 CPU 上使用 optimum-intel 優化運行 ProtST:
...
protst_model = AutoModel.from_pretrained("mila-intel/ProtST-esm1b", trust_remote_code=True, torch_dtype=torch.bfloat16).to(device)
protein_model = protst_model.protein_model
import intel_extension_for_pytorch as ipex
from optimum.intel.generation.modeling import jit_trace
protein_model = ipex.optimize(protein_model, dtype=torch.bfloat16, inplace=True)
protein_model = jit_trace(protein_model, "sequence-classification")
...
📚 詳細文檔
性能表現
零樣本 ProtST - ESM - 1b 優於少樣本分類器。
模型架構
代碼倉庫
源代碼和模型權重可在 https://github.com/DeepGraphLearning/ProtST 獲取。
Esm2 T36 3B UR50D
MIT
ESM-2是基於掩碼語言建模目標訓練的新一代蛋白質模型,適用於各類以蛋白質序列為輸入的下游任務微調。
蛋白質模型
Transformers

E
facebook
3.5M
22
Esm2 T6 8M UR50D
MIT
ESM-2是基於掩碼語言建模目標訓練的新一代蛋白質模型,適用於對蛋白質序列進行各類任務的微調。
蛋白質模型
Transformers

E
facebook
1.5M
21
Esm2 T33 650M UR50D
MIT
ESM-2是基於掩碼語言建模目標訓練的最先進蛋白質模型,適用於對蛋白質序列進行分析和預測任務
蛋白質模型
Transformers

E
facebook
640.23k
41
Esm2 T12 35M UR50D
MIT
ESM-2是基於掩碼語言建模目標訓練的前沿蛋白質模型,適用於各類蛋白質序列分析任務
蛋白質模型
Transformers

E
facebook
332.83k
15
Prot Bert
基於BERT架構的蛋白質序列預訓練模型,通過自監督學習捕捉蛋白質序列的生物物理特性
蛋白質模型
Transformers

P
Rostlab
276.10k
111
Prostt5
MIT
ProstT5是一種蛋白質語言模型,能夠在蛋白質序列與結構之間進行翻譯。
蛋白質模型
Transformers

P
Rostlab
252.91k
23
Prot T5 Xl Uniref50
基於T5-3B架構的蛋白質序列預訓練模型,通過自監督學習捕捉蛋白質的生物物理特性
蛋白質模型
Transformers

P
Rostlab
78.45k
44
Esm2 T30 150M UR50D
MIT
ESM-2是基於遮蔽語言建模目標訓練的最先進蛋白質模型,適用於對各類以蛋白質序列為輸入的任務進行微調。
蛋白質模型
Transformers

E
facebook
69.91k
7
Prot Bert Bfd
基於Bert架構的蛋白質序列預訓練模型,通過自監督學習從21億蛋白質序列中提取生物物理特徵
蛋白質模型
Transformers

P
Rostlab
30.60k
16
Esm1b T33 650M UR50S
MIT
ESM-1b是基於Transformer的蛋白質語言模型,通過無監督學習蛋白質序列數據,可用於蛋白質結構和功能預測。
蛋白質模型
Transformers

E
facebook
24.20k
18
精選推薦AI模型
Llama 3 Typhoon V1.5x 8b Instruct
專為泰語設計的80億參數指令模型,性能媲美GPT-3.5-turbo,優化了應用場景、檢索增強生成、受限生成和推理任務
大型語言模型
Transformers 支持多種語言

L
scb10x
3,269
16
Cadet Tiny
Openrail
Cadet-Tiny是一個基於SODA數據集訓練的超小型對話模型,專為邊緣設備推理設計,體積僅為Cosmo-3B模型的2%左右。
對話系統
Transformers 英語

C
ToddGoldfarb
2,691
6
Roberta Base Chinese Extractive Qa
基於RoBERTa架構的中文抽取式問答模型,適用於從給定文本中提取答案的任務。
問答系統 中文
R
uer
2,694
98