🚀 小型ELECTRA模型⚡ + SQuAD v2问答数据集❓
本项目基于 小型ELECTRA判别器,在 SQuAD v2.0数据集 上进行微调,以用于问答(Q&A)下游任务。该模型能够有效处理问答任务,在问答准确性和效率上有不错的表现。
🚀 快速开始
你可以使用以下代码快速调用该模型进行问答任务:
from transformers import pipeline
QnA_pipeline = pipeline('question-answering', model='mrm8488/electra-base-finetuned-squadv2')
QnA_pipeline({
'context': 'A new strain of flu that has the potential to become a pandemic has been identified in China by scientists.',
'question': 'What has been discovered by scientists from China ?'
})
{'answer': 'A new strain of flu', 'end': 19, 'score': 0.8650811568752914, 'start': 0}
✨ 主要特性
- 高效预训练:ELECTRA是一种用于自监督语言表征学习的新方法,能以相对较少的计算资源预训练Transformer网络。
- 处理复杂问答:SQuAD2.0数据集包含大量可回答和不可回答的问题,模型需要判断何时无法从段落中找到答案并放弃回答。
📦 安装指南
模型在Tesla P100 GPU和25GB内存的环境下,使用以下命令进行训练:
python transformers/examples/question-answering/run_squad.py \
--model_type electra \
--model_name_or_path 'google/electra-small-discriminator' \
--do_eval \
--do_train \
--do_lower_case \
--train_file '/content/dataset/train-v2.0.json' \
--predict_file '/content/dataset/dev-v2.0.json' \
--per_gpu_train_batch_size 16 \
--learning_rate 3e-5 \
--num_train_epochs 10 \
--max_seq_length 384 \
--doc_stride 128 \
--output_dir '/content/output' \
--overwrite_output_dir \
--save_steps 1000 \
--version_2_with_negative
📚 详细文档
下游任务详情 - 模型 🧠
ELECTRA 是一种用于自监督语言表征学习的新方法,它可以用相对较少的计算资源来预训练Transformer网络。ELECTRA模型通过训练来区分“真实”输入标记和由另一个神经网络生成的“虚假”输入标记,类似于 生成对抗网络(GAN) 中的判别器。在小规模情况下,即使在单个GPU上进行训练,ELECTRA也能取得不错的效果。在大规模情况下,ELECTRA在 SQuAD 2.0 数据集上取得了最先进的结果。
下游任务详情 - 数据集 📚
SQuAD2.0 将SQuAD1.1中的100,000个问题与超过50,000个由众包工作者对抗性编写的无法回答的问题相结合,这些问题看起来与可回答的问题相似。要在SQuAD2.0上表现良好,系统不仅要在可能的情况下回答问题,还要判断段落中何时不支持答案并放弃回答。
📦 测试集结果 🧾
指标 |
值 |
精确匹配率(EM) |
69.71 |
F1分数 |
73.44 |
模型大小 |
50 MB |
以下是详细的JSON格式测试结果:
{
'exact': 69.71279373368147,
'f1': 73.4439546123672,
'total': 11873,
'HasAns_exact': 69.92240215924427,
'HasAns_f1': 77.39542393937836,
'HasAns_total': 5928,
'NoAns_exact': 69.50378469301934,
'NoAns_f1': 69.50378469301934,
'NoAns_total': 5945,
'best_exact': 69.71279373368147,
'best_exact_thresh': 0.0,
'best_f1': 73.44395461236732,
'best_f1_thresh': 0.0
}
📄 许可证
本项目采用Apache-2.0许可证。
由 Manuel Romero/@mrm8488 创建 | 领英
于西班牙用心打造 ♥