🚀 剪枝BERT模型
本模型在BookCorpus数据集上使用知识蒸馏技术进行预训练。它与BERT采用相同的架构,但隐藏层大小为256(仅为BERT隐藏层大小的三分之一),且具有4个注意力头(与BERT的头大小相同)。该模型的权重通过对bert-base-uncased
的权重进行剪枝初始化,并使用多种损失函数进行知识蒸馏以微调模型。此外,该模型使用的分词器与bert-base-uncased
相同。
🚀 快速开始
加载模型和分词器
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
对句子进行预测
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
预测最相关的前n个结果
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)
✨ 主要特性
- 轻量级架构:隐藏层大小仅为256,是BERT隐藏层大小的三分之一,减少了计算资源的需求。
- 知识蒸馏微调:使用多种损失函数进行知识蒸馏,在保留性能的同时优化了模型。
- 通用分词器:使用与
bert-base-uncased
相同的分词器,方便集成到现有系统中。
🔧 技术细节
本模型的训练过程分为两个主要步骤:权重初始化和知识蒸馏微调。首先,通过对bert-base-uncased
的权重进行剪枝,初始化模型的权重。然后,使用多种损失函数进行知识蒸馏,对模型进行微调。这种方法使得模型在保持相对较小的参数规模的同时,仍能取得较好的性能。
在架构方面,模型与BERT相同,但隐藏层大小为256,注意力头数量为4,这使得模型在计算效率和性能之间取得了平衡。
💻 使用示例
基础用法
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
高级用法
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)