🚀 剪枝BERTモデル
このモデルは、BookCorpusデータセットを使用し、知識蒸留技術を用いて事前学習されています。BERTと同じアーキテクチャを持ちながら、隠れ層のサイズは256(BERTの隠れ層サイズの3分の1)で、アテンションヘッドは4つ(BERTと同じヘッドサイズ)です。モデルの重みは、bert-base-uncased
の重みを剪定して初期化され、複数の損失関数を使用した知識蒸留によって微調整されます。また、このモデルはbert-base-uncased
と同じトークナイザーを使用しています。
🚀 クイックスタート
モデルとトークナイザーの読み込み
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
文章の予測
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
最も関連性の高い上位n個の結果を予測
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)
✨ 主な機能
- 軽量アーキテクチャ:隠れ層のサイズが256と、BERTの隠れ層サイズの3分の1であり、計算リソースの必要量を削減します。
- 知識蒸留による微調整:複数の損失関数を使用した知識蒸留により、性能を維持しながらモデルを最適化します。
- 汎用トークナイザー:
bert-base-uncased
と同じトークナイザーを使用するため、既存のシステムに容易に統合できます。
🔧 技術詳細
このモデルの学習プロセスは、重みの初期化と知識蒸留による微調整の2つの主要なステップに分かれています。まず、bert-base-uncased
の重みを剪定することでモデルの重みを初期化します。次に、複数の損失関数を使用した知識蒸留によってモデルを微調整します。この方法により、モデルは比較的小さいパラメータ規模を維持しながら、良好な性能を達成することができます。
アーキテクチャに関しては、モデルはBERTと同じですが、隠れ層のサイズが256、アテンションヘッドの数が4つであり、計算効率と性能のバランスを取っています。
💻 使用例
基本的な使用法
from transformers import AutoModelForMaskedLM, BertTokenizer
model_name = "eli4s/prunedBert-L12-h256-A4-finetuned"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)
import torch
sentence = "Let's have a [MASK]."
model.eval()
inputs = tokenizer([sentence], padding='longest', return_tensors='pt')
output = model(inputs['input_ids'], attention_mask=inputs['attention_mask'])
mask_index = inputs['input_ids'].tolist()[0].index(103)
masked_token = output['logits'][0][mask_index].argmax(axis=-1)
predicted_token = tokenizer.decode(masked_token)
print(predicted_token)
高度な使用法
top_n = 5
vocab_size = model.config.vocab_size
logits = output['logits'][0][mask_index].tolist()
top_tokens = sorted(list(range(vocab_size)), key=lambda i:logits[i], reverse=True)[:top_n]
tokenizer.decode(top_tokens)