🚀 剪枝BERT模型
本模型在BookCorpus數據集上使用知識蒸餾技術進行預訓練。它與BERT採用相同的架構,但隱藏層大小為256(僅為BERT隱藏層大小的三分之一),且具有4個注意力頭(與BERT的頭大小相同)。該模型的權重通過對bert-base-uncased
的權重進行剪枝初始化,並使用多種損失函數進行知識蒸餾以微調模型。此外,該模型使用的分詞器與bert-base-uncased
相同。
🚀 快速開始
加載模型和分詞器
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
對句子進行預測
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
預測最相關的前n個結果
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)
✨ 主要特性
- 輕量級架構:隱藏層大小僅為256,是BERT隱藏層大小的三分之一,減少了計算資源的需求。
- 知識蒸餾微調:使用多種損失函數進行知識蒸餾,在保留性能的同時優化了模型。
- 通用分詞器:使用與
bert-base-uncased
相同的分詞器,方便集成到現有系統中。
🔧 技術細節
本模型的訓練過程分為兩個主要步驟:權重初始化和知識蒸餾微調。首先,通過對bert-base-uncased
的權重進行剪枝,初始化模型的權重。然後,使用多種損失函數進行知識蒸餾,對模型進行微調。這種方法使得模型在保持相對較小的參數規模的同時,仍能取得較好的性能。
在架構方面,模型與BERT相同,但隱藏層大小為256,注意力頭數量為4,這使得模型在計算效率和性能之間取得了平衡。
💻 使用示例
基礎用法
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
高級用法
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)