🚀 XtremeDistilTransformers:用于蒸馏大规模神经网络
XtremeDistilTransformers是一个经过蒸馏的、与任务无关的Transformer模型。它利用任务迁移来学习一个小型通用模型,该模型可以应用于任意任务和语言,相关内容在论文XtremeDistilTransformers: Task Transfer for Task-agnostic Distillation中有详细阐述。
我们结合了论文XtremeDistil: Multi-stage Distillation for Massive Multilingual Models和MiniLM: Deep Self-Attention Distillation for Task-Agnostic Compression of Pre-Trained Transformers中的多任务蒸馏技术以及任务迁移方法,并提供了对应的Github代码。
这个l6 - h384检查点有6层,隐藏层大小为384,有12个注意力头,对应2200万个参数,相比BERT - base速度提升了5.3倍。
其他可用的检查点:xtremedistil - l6 - h256 - uncased和xtremedistil - l12 - h384 - uncased
下表展示了在GLUE开发集和SQuAD - v2上的实验结果。
模型 |
参数数量 |
加速倍数 |
MNLI |
QNLI |
QQP |
RTE |
SST |
MRPC |
SQUAD2 |
平均 |
BERT |
1.09亿 |
1倍 |
84.5 |
91.7 |
91.3 |
68.6 |
93.2 |
87.3 |
76.8 |
84.8 |
DistilBERT |
6600万 |
2倍 |
82.2 |
89.2 |
88.5 |
59.9 |
91.3 |
87.5 |
70.7 |
81.3 |
TinyBERT |
6600万 |
2倍 |
83.5 |
90.5 |
90.6 |
72.2 |
91.6 |
88.4 |
73.1 |
84.3 |
MiniLM |
6600万 |
2倍 |
84.0 |
91.0 |
91.0 |
71.5 |
92.0 |
88.4 |
76.4 |
84.9 |
MiniLM |
2200万 |
5.3倍 |
82.8 |
90.3 |
90.6 |
68.9 |
91.3 |
86.6 |
72.9 |
83.3 |
XtremeDistil - l6 - h256 |
1300万 |
8.7倍 |
83.9 |
89.5 |
90.6 |
80.1 |
91.2 |
90.0 |
74.1 |
85.6 |
XtremeDistil - l6 - h384 |
2200万 |
5.3倍 |
85.4 |
90.3 |
91.0 |
80.9 |
92.3 |
90.0 |
76.6 |
86.6 |
XtremeDistil - l12 - h384 |
3300万 |
2.7倍 |
87.2 |
91.9 |
91.3 |
85.6 |
93.1 |
90.4 |
80.2 |
88.5 |
测试环境为tensorflow 2.3.1, transformers 4.1.1, torch 1.6.0
如果您在工作中使用了这个检查点,请引用以下文献:
@misc{mukherjee2021xtremedistiltransformers,
title={XtremeDistilTransformers: Task Transfer for Task-agnostic Distillation},
author={Subhabrata Mukherjee and Ahmed Hassan Awadallah and Jianfeng Gao},
year={2021},
eprint={2106.04563},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
🚀 快速开始
XtremeDistilTransformers是一个经过蒸馏的、与任务无关的Transformer模型,它借助任务迁移学习小型通用模型,可应用于任意任务和语言。
✨ 主要特性
- 利用任务迁移和多任务蒸馏技术,学习小型通用模型。
- 有多种检查点可供选择,如l6 - h384、l6 - h256 - uncased、l12 - h384 - uncased等。
- 相比BERT - base有显著的速度提升,如l6 - h384检查点速度提升了5.3倍。
- 在GLUE开发集和SQuAD - v2上有较好的实验结果。
📚 详细文档
模型信息
这个l6 - h384检查点有6层,隐藏层大小为384,有12个注意力头,对应2200万个参数。
其他检查点
实验结果
模型 |
参数数量 |
加速倍数 |
MNLI |
QNLI |
QQP |
RTE |
SST |
MRPC |
SQUAD2 |
平均 |
BERT |
1.09亿 |
1倍 |
84.5 |
91.7 |
91.3 |
68.6 |
93.2 |
87.3 |
76.8 |
84.8 |
DistilBERT |
6600万 |
2倍 |
82.2 |
89.2 |
88.5 |
59.9 |
91.3 |
87.5 |
70.7 |
81.3 |
TinyBERT |
6600万 |
2倍 |
83.5 |
90.5 |
90.6 |
72.2 |
91.6 |
88.4 |
73.1 |
84.3 |
MiniLM |
6600万 |
2倍 |
84.0 |
91.0 |
91.0 |
71.5 |
92.0 |
88.4 |
76.4 |
84.9 |
MiniLM |
2200万 |
5.3倍 |
82.8 |
90.3 |
90.6 |
68.9 |
91.3 |
86.6 |
72.9 |
83.3 |
XtremeDistil - l6 - h256 |
1300万 |
8.7倍 |
83.9 |
89.5 |
90.6 |
80.1 |
91.2 |
90.0 |
74.1 |
85.6 |
XtremeDistil - l6 - h384 |
2200万 |
5.3倍 |
85.4 |
90.3 |
91.0 |
80.9 |
92.3 |
90.0 |
76.6 |
86.6 |
XtremeDistil - l12 - h384 |
3300万 |
2.7倍 |
87.2 |
91.9 |
91.3 |
85.6 |
93.1 |
90.4 |
80.2 |
88.5 |
测试环境
测试使用了tensorflow 2.3.1, transformers 4.1.1, torch 1.6.0
。
📄 许可证
本项目采用MIT许可证。