🚀 Ankh3蛋白质语言模型
Ankh3是一个蛋白质语言模型,它在两个目标上进行了联合优化,可用于特征提取,为蛋白质相关研究提供支持。
🚀 快速开始
模型信息
属性 |
详情 |
库名称 |
transformers |
许可证 |
cc - by - nc - sa - 4.0 |
任务类型 |
特征提取 |
模型类型 |
蛋白质语言模型 |
训练数据 |
UniRef50 |
模型细节
Ankh3是一个蛋白质语言模型,它在两个目标上进行了联合优化:
1. 掩码语言建模
- 此任务的思路是,通过对输入蛋白质序列中一定比例(X%)的单个标记(氨基酸)进行掩码处理,故意“损坏”输入的蛋白质序列,然后训练模型来重构原始序列。
- 以下是一个蛋白质序列在损坏前后的示例:
- 原始蛋白质序列:MKAYVLINSRGP
- 该序列将使用哨兵标记进行掩码/损坏,如下所示:
损坏后的序列:M <extra_id_0> A Y <extra_id_1> L I <extra_id_2> S R G <extra_id_3>
- 解码器学习将每个哨兵标记对应到被掩码的实际氨基酸。在这个例子中:<extra_id_0> K 表示 <extra_id_0> 对应于 “K” 氨基酸,依此类推。
- 解码器输出:<extra_id_0> K <extra_id_1> V <extra_id_2> N <extra_id_3> P
2. 蛋白质序列补全
- 此任务的思路是将输入序列切成两段,第一段输入到编码器,解码器的任务是根据编码器输出的第一段表示,自回归地生成第二段。
- 以下是蛋白质序列补全的示例:
- 原始序列:MKAYVLINSRGP
- 我们将 “MKAYVL” 输入到编码器,解码器经过训练,在给定编码器提供的第一部分表示的情况下,应该输出第二部分,即:“INSRGP”
💻 使用示例
基础用法 - 嵌入提取
from transformers import T5ForConditionalGeneration, T5Tokenizer, T5EncoderModel
import torch
sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"
ckpt = "ElnaggarLab/ankh3-xl"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
encoder_model = T5EncoderModel.from_pretrained(ckpt).eval()
nlu_sequence = "[NLU]" + sequence
encoded_nlu_sequence = tokenizer(nlu_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
with torch.no_grad():
embedding = encoder_model(**encoded_nlu_sequence)
高级用法 - 序列补全
from transformers import T5ForConditionalGeneration, T5Tokenizer
from transformers.generation import GenerationConfig
import torch
sequence = "MDTAYPREDTRAPTPSKAGAHTALTLGAPHPPPRDHLIWSVFSTLYLNLCCLGFLALAYSIKARDQKVVGDLEAARRFGSKAKCYNILAAMWTLVPPLLLLGLVVTGALHLARLAKDSAAFFSTKFDDADYD"
ckpt = "ElnaggarLab/ankh3-xl"
tokenizer = T5Tokenizer.from_pretrained(ckpt)
model = T5ForConditionalGeneration.from_pretrained(ckpt).eval()
half_length = int(len(sequence) * 0.5)
s2s_sequence = "[S2S]" + sequence[:half_length]
encoded_s2s_sequence = tokenizer(s2s_sequence, add_special_tokens=True, return_tensors="pt", is_split_into_words=False)
gen_config = GenerationConfig(min_length=half_length + 1, max_length=half_length + 1, do_sample=False, num_beams=1)
generated_sequence = model.generate(encoded_s2s_sequence["input_ids"], gen_config, )
predicted_sequence = sequence[:half_length] + tokenizer.batch_decode(generated_sequence)[0]
📄 许可证
本项目采用cc - by - nc - sa - 4.0许可证。