🚀 XtremeDistilTransformers:用於蒸餾大規模神經網絡
XtremeDistilTransformers是一個經過蒸餾的、與任務無關的Transformer模型。它利用任務遷移來學習一個小型通用模型,該模型可以應用於任意任務和語言,相關內容在論文XtremeDistilTransformers: Task Transfer for Task-agnostic Distillation中有詳細闡述。
我們結合了論文XtremeDistil: Multi-stage Distillation for Massive Multilingual Models和MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers中的多任務蒸餾技術以及任務遷移方法,並提供了對應的Github代碼。
這個l6 - h384檢查點有6層,隱藏層大小為384,有12個注意力頭,對應2200萬個參數,相比BERT - base速度提升了5.3倍。
其他可用的檢查點:xtremedistil - l6 - h256 - uncased和xtremedistil - l12 - h384 - uncased
下表展示了在GLUE開發集和SQuAD - v2上的實驗結果。
模型 |
參數數量 |
加速倍數 |
MNLI |
QNLI |
QQP |
RTE |
SST |
MRPC |
SQUAD2 |
平均 |
BERT |
1.09億 |
1倍 |
84.5 |
91.7 |
91.3 |
68.6 |
93.2 |
87.3 |
76.8 |
84.8 |
DistilBERT |
6600萬 |
2倍 |
82.2 |
89.2 |
88.5 |
59.9 |
91.3 |
87.5 |
70.7 |
81.3 |
TinyBERT |
6600萬 |
2倍 |
83.5 |
90.5 |
90.6 |
72.2 |
91.6 |
88.4 |
73.1 |
84.3 |
MiniLM |
6600萬 |
2倍 |
84.0 |
91.0 |
91.0 |
71.5 |
92.0 |
88.4 |
76.4 |
84.9 |
MiniLM |
2200萬 |
5.3倍 |
82.8 |
90.3 |
90.6 |
68.9 |
91.3 |
86.6 |
72.9 |
83.3 |
XtremeDistil - l6 - h256 |
1300萬 |
8.7倍 |
83.9 |
89.5 |
90.6 |
80.1 |
91.2 |
90.0 |
74.1 |
85.6 |
XtremeDistil - l6 - h384 |
2200萬 |
5.3倍 |
85.4 |
90.3 |
91.0 |
80.9 |
92.3 |
90.0 |
76.6 |
86.6 |
XtremeDistil - l12 - h384 |
3300萬 |
2.7倍 |
87.2 |
91.9 |
91.3 |
85.6 |
93.1 |
90.4 |
80.2 |
88.5 |
測試環境為tensorflow 2.3.1, transformers 4.1.1, torch 1.6.0
如果您在工作中使用了這個檢查點,請引用以下文獻:
@misc{mukherjee2021xtremedistiltransformers,
title={XtremeDistilTransformers: Task Transfer for Task-agnostic Distillation},
author={Subhabrata Mukherjee and Ahmed Hassan Awadallah and Jianfeng Gao},
year={2021},
eprint={2106.04563},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
🚀 快速開始
XtremeDistilTransformers是一個經過蒸餾的、與任務無關的Transformer模型,它藉助任務遷移學習小型通用模型,可應用於任意任務和語言。
✨ 主要特性
- 利用任務遷移和多任務蒸餾技術,學習小型通用模型。
- 有多種檢查點可供選擇,如l6 - h384、l6 - h256 - uncased、l12 - h384 - uncased等。
- 相比BERT - base有顯著的速度提升,如l6 - h384檢查點速度提升了5.3倍。
- 在GLUE開發集和SQuAD - v2上有較好的實驗結果。
📚 詳細文檔
模型信息
這個l6 - h384檢查點有6層,隱藏層大小為384,有12個注意力頭,對應2200萬個參數。
其他檢查點
實驗結果
模型 |
參數數量 |
加速倍數 |
MNLI |
QNLI |
QQP |
RTE |
SST |
MRPC |
SQUAD2 |
平均 |
BERT |
1.09億 |
1倍 |
84.5 |
91.7 |
91.3 |
68.6 |
93.2 |
87.3 |
76.8 |
84.8 |
DistilBERT |
6600萬 |
2倍 |
82.2 |
89.2 |
88.5 |
59.9 |
91.3 |
87.5 |
70.7 |
81.3 |
TinyBERT |
6600萬 |
2倍 |
83.5 |
90.5 |
90.6 |
72.2 |
91.6 |
88.4 |
73.1 |
84.3 |
MiniLM |
6600萬 |
2倍 |
84.0 |
91.0 |
91.0 |
71.5 |
92.0 |
88.4 |
76.4 |
84.9 |
MiniLM |
2200萬 |
5.3倍 |
82.8 |
90.3 |
90.6 |
68.9 |
91.3 |
86.6 |
72.9 |
83.3 |
XtremeDistil - l6 - h256 |
1300萬 |
8.7倍 |
83.9 |
89.5 |
90.6 |
80.1 |
91.2 |
90.0 |
74.1 |
85.6 |
XtremeDistil - l6 - h384 |
2200萬 |
5.3倍 |
85.4 |
90.3 |
91.0 |
80.9 |
92.3 |
90.0 |
76.6 |
86.6 |
XtremeDistil - l12 - h384 |
3300萬 |
2.7倍 |
87.2 |
91.9 |
91.3 |
85.6 |
93.1 |
90.4 |
80.2 |
88.5 |
測試環境
測試使用了tensorflow 2.3.1, transformers 4.1.1, torch 1.6.0
。
📄 許可證
本項目採用MIT許可證。