🚀 Ankh3蛋白質語言模型
Ankh3是一個蛋白質語言模型,它在兩個目標上進行了聯合優化,可用於特徵提取,為蛋白質相關研究提供支持。
🚀 快速開始
模型信息
屬性 |
詳情 |
庫名稱 |
transformers |
許可證 |
cc - by - nc - sa - 4.0 |
任務類型 |
特徵提取 |
模型類型 |
蛋白質語言模型 |
訓練數據 |
UniRef50 |
模型細節
Ankh3是一個蛋白質語言模型,它在兩個目標上進行了聯合優化:
1. 掩碼語言建模
- 此任務的思路是,通過對輸入蛋白質序列中一定比例(X%)的單個標記(氨基酸)進行掩碼處理,故意“損壞”輸入的蛋白質序列,然後訓練模型來重構原始序列。
- 以下是一個蛋白質序列在損壞前後的示例:
- 原始蛋白質序列:MKAYVLINSRGP
- 該序列將使用哨兵標記進行掩碼/損壞,如下所示:
損壞後的序列:M <extra_id_0> A Y <extra_id_1> L I <extra_id_2> S R G <extra_id_3>
- 解碼器學習將每個哨兵標記對應到被掩碼的實際氨基酸。在這個例子中:<extra_id_0> K 表示 <extra_id_0> 對應於 “K” 氨基酸,依此類推。
- 解碼器輸出:<extra_id_0> K <extra_id_1> V <extra_id_2> N <extra_id_3> P
2. 蛋白質序列補全
- 此任務的思路是將輸入序列切成兩段,第一段輸入到編碼器,解碼器的任務是根據編碼器輸出的第一段表示,自迴歸地生成第二段。
- 以下是蛋白質序列補全的示例:
- 原始序列:MKAYVLINSRGP
- 我們將 “MKAYVL” 輸入到編碼器,解碼器經過訓練,在給定編碼器提供的第一部分表示的情況下,應該輸出第二部分,即:“INSRGP”
💻 使用示例
基礎用法 - 嵌入提取
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5EncoderModel
import torch
sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"
ckpt = "ElnaggarLab/ankh3-xl"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
encoder_model = T5EncoderModel.from_pretrained(ckpt).eval()
nlu_sequence = "[NLU]" + sequence
encoded_nlu_sequence = tokenizer(nlu_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
with torch.no_grad():
embedding = encoder_model(**encoded_nlu_sequence)
高級用法 - 序列補全
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig
import torch
sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"
ckpt = "ElnaggarLab/ankh3-xl"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()
half_length = int(len(sequence) * 0.5)
s2s_sequence = "[S2S]" + sequence[:half_length]
encoded_s2s_sequence = tokenizer(s2s_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
gen_config = GenerationConfig(min_length=half_length + 1, max_length=half_length + 1, do_sample=False, num_beams=1)
generated_sequence = model.generate(encoded_s2s_sequence["input_ids"], gen_config, )
predicted_sequence = sequence[:half_length] + tokenizer.batch_decode(generated_sequence)[0]
📄 許可證
本項目採用cc - by - nc - sa - 4.0許可證。