🚀 roberta_qa_japanese
该模型是用于日语(抽取型)问答的模型,能基于给定文本准确抽取答案,解决日语问答需求
这个模型是 rinna/japanese-roberta-base(由rinna株式会社提供的预训练RoBERTa模型)的微调版本,专为抽取式问答任务而训练。
该模型在Skelter Labs提供的 JaQuAD 数据集上进行了微调,该数据集的数据来自日本维基百科文章,并经过人工标注。
🚀 快速开始
使用专用管道运行
from transformers import pipeline
model_name = "tsmatz/roberta_qa_japanese"
qa_pipeline = pipeline(
"question-answering",
model=model_name,
tokenizer=model_name)
result = qa_pipeline(
question = "決勝トーナメントで日本に勝ったのはどこでしたか。",
context = "日本は予選リーグで強豪のドイツとスペインに勝って決勝トーナメントに進んだが、クロアチアと対戦して敗れた。",
align_to_words = False,
)
print(result)
手动前向传播运行
import torch
import numpy as np
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
model_name = "tsmatz/roberta_qa_japanese"
model = (AutoModelForQuestionAnswering
.from_pretrained(model_name))
tokenizer = AutoTokenizer.from_pretrained(model_name)
def inference_answer(question, context):
question = question
context = context
test_feature = tokenizer(
question,
context,
max_length=318,
)
with torch.no_grad():
outputs = model(torch.tensor([test_feature["input_ids"]]))
start_logits = outputs.start_logits.cpu().numpy()
end_logits = outputs.end_logits.cpu().numpy()
answer_ids = test_feature["input_ids"][np.argmax(start_logits):np.argmax(end_logits)+1]
return "".join(tokenizer.batch_decode(answer_ids))
question = "決勝トーナメントで日本に勝ったのはどこでしたか。"
context = "日本は予選リーグで強豪のドイツとスペインに勝って決勝トーナメントに進んだが、クロアチアと対戦して敗れた。"
answer_pred = inference_answer(question, context)
print(answer_pred)
📚 详细文档
训练过程
你可以从 这里 下载微调的源代码。
训练超参数
训练期间使用了以下超参数:
- 学习率:7e-05
- 训练批次大小:2
- 评估批次大小:1
- 随机种子:42
- 梯度累积步数:16
- 总训练批次大小:32
- 优化器:Adam(β=(0.9, 0.999),ε=1e-08)
- 学习率调度器类型:线性
- 学习率调度器热身步数:100
- 训练轮数:3
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
2.1293 |
0.13 |
150 |
1.0311 |
1.1965 |
0.26 |
300 |
0.6723 |
1.022 |
0.39 |
450 |
0.4838 |
0.9594 |
0.53 |
600 |
0.5174 |
0.9187 |
0.66 |
750 |
0.4671 |
0.8229 |
0.79 |
900 |
0.4650 |
0.71 |
0.92 |
1050 |
0.2648 |
0.5436 |
1.05 |
1200 |
0.2665 |
0.5045 |
1.19 |
1350 |
0.2686 |
0.5025 |
1.32 |
1500 |
0.2082 |
0.5213 |
1.45 |
1650 |
0.1715 |
0.4648 |
1.58 |
1800 |
0.1563 |
0.4698 |
1.71 |
1950 |
0.1488 |
0.4823 |
1.84 |
2100 |
0.1050 |
0.4482 |
1.97 |
2250 |
0.0821 |
0.2755 |
2.11 |
2400 |
0.0898 |
0.2834 |
2.24 |
2550 |
0.0964 |
0.2525 |
2.37 |
2700 |
0.0533 |
0.2606 |
2.5 |
2850 |
0.0561 |
0.2467 |
2.63 |
3000 |
0.0601 |
0.2799 |
2.77 |
3150 |
0.0562 |
0.2497 |
2.9 |
3300 |
0.0516 |
框架版本
- Transformers 4.23.1
- Pytorch 1.12.1+cu102
- Datasets 2.6.1
- Tokenizers 0.13.1
📄 许可证
本项目采用MIT许可证。