🚀 BART-Squad2
BART-Squad2 是一个用于抽取式(基于文本片段)问答任务的模型,它在Squad 2.0数据集上进行训练,能够有效从文本中提取答案,为问答系统提供了强大的支持。
🚀 快速开始
本地运行问答示例
以下是在本地运行问答的快速方法:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("Primer/bart-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("Primer/bart-squad2")
model.to('cuda'); model.eval()
def answer(question, text):
seq = '<s>' + question + ' </s> </s> ' + text + ' </s>'
tokens = tokenizer.encode_plus(seq, return_tensors='pt', padding='max_length', max_length=1024)
input_ids = tokens['input_ids'].to('cuda')
attention_mask = tokens['attention_mask'].to('cuda')
start, end, _ = model(input_ids, attention_mask=attention_mask)
start_idx = int(start.argmax().int())
end_idx = int(end.argmax().int())
print(tokenizer.decode(input_ids[0, start_idx:end_idx]).strip())
>>> question = "Where does Tom live?"
>>> context = "Tom is an engineer in San Francisco."
>>> answer(question, context)
San Francisco
注意:如果在CPU上运行,去掉 .to('cuda')
相关代码即可。
✨ 主要特性
- 高准确率:在Squad 2.0数据集上训练,F1分数达到87.4。
- 抽取式问答:适用于抽取式(基于文本片段)的问答任务。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("Primer/bart-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("Primer/bart-squad2")
model.to('cuda'); model.eval()
def answer(question, text):
seq = '<s>' + question + ' </s> </s> ' + text + ' </s>'
tokens = tokenizer.encode_plus(seq, return_tensors='pt', padding='max_length', max_length=1024)
input_ids = tokens['input_ids'].to('cuda')
attention_mask = tokens['attention_mask'].to('cuda')
start, end, _ = model(input_ids, attention_mask=attention_mask)
start_idx = int(start.argmax().int())
end_idx = int(end.argmax().int())
print(tokenizer.decode(input_ids[0, start_idx:end_idx]).strip())
question = "Where does Tom live?"
context = "Tom is an engineer in San Francisco."
answer(question, context)
🔧 技术细节
训练参数
使用 run_squad.py
脚本进行训练,具体参数如下:
参数 |
值 |
批量大小 |
8 |
最大序列长度 |
1024 |
学习率 |
1e-5 |
训练轮数 |
2 |
模型修改
训练过程中对模型进行了修改,冻结了共享参数和编码器嵌入层。
📄 许可证
文档未提及相关许可证信息。
⚠️ 重要提示
很遗憾,Huggingface的自动推理API无法运行此模型。如果你尝试通过上方输入框运行模型并遇到报错,不必气馁!
💡 使用建议
模型大小为1.6G,在使用时请确保有足够的存储空间和计算资源。