🚀 BART-Squad2
BART-Squad2 是一個用於抽取式(基於文本片段)問答任務的模型,它在Squad 2.0數據集上進行訓練,能夠有效從文本中提取答案,為問答系統提供了強大的支持。
🚀 快速開始
本地運行問答示例
以下是在本地運行問答的快速方法:
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("Primer/bart-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("Primer/bart-squad2")
model.to('cuda'); model.eval()
def answer(question, text):
seq = '<s>' + question + ' </s> </s> ' + text + ' </s>'
tokens = tokenizer.encode_plus(seq, return_tensors='pt', padding='max_length', max_length=1024)
input_ids = tokens['input_ids'].to('cuda')
attention_mask = tokens['attention_mask'].to('cuda')
start, end, _ = model(input_ids, attention_mask=attention_mask)
start_idx = int(start.argmax().int())
end_idx = int(end.argmax().int())
print(tokenizer.decode(input_ids[0, start_idx:end_idx]).strip())
>>> question = "Where does Tom live?"
>>> context = "Tom is an engineer in San Francisco."
>>> answer(question, context)
San Francisco
注意:如果在CPU上運行,去掉 .to('cuda')
相關代碼即可。
✨ 主要特性
- 高準確率:在Squad 2.0數據集上訓練,F1分數達到87.4。
- 抽取式問答:適用於抽取式(基於文本片段)的問答任務。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
tokenizer = AutoTokenizer.from_pretrained("Primer/bart-squad2")
model = AutoModelForQuestionAnswering.from_pretrained("Primer/bart-squad2")
model.to('cuda'); model.eval()
def answer(question, text):
seq = '<s>' + question + ' </s> </s> ' + text + ' </s>'
tokens = tokenizer.encode_plus(seq, return_tensors='pt', padding='max_length', max_length=1024)
input_ids = tokens['input_ids'].to('cuda')
attention_mask = tokens['attention_mask'].to('cuda')
start, end, _ = model(input_ids, attention_mask=attention_mask)
start_idx = int(start.argmax().int())
end_idx = int(end.argmax().int())
print(tokenizer.decode(input_ids[0, start_idx:end_idx]).strip())
question = "Where does Tom live?"
context = "Tom is an engineer in San Francisco."
answer(question, context)
🔧 技術細節
訓練參數
使用 run_squad.py
腳本進行訓練,具體參數如下:
參數 |
值 |
批量大小 |
8 |
最大序列長度 |
1024 |
學習率 |
1e-5 |
訓練輪數 |
2 |
模型修改
訓練過程中對模型進行了修改,凍結了共享參數和編碼器嵌入層。
📄 許可證
文檔未提及相關許可證信息。
⚠️ 重要提示
很遺憾,Huggingface的自動推理API無法運行此模型。如果你嘗試通過上方輸入框運行模型並遇到報錯,不必氣餒!
💡 使用建議
模型大小為1.6G,在使用時請確保有足夠的存儲空間和計算資源。