🚀 用於干擾項生成的T5-large模型
本倉庫包含一個針對干擾項生成任務進行微調的T5-large模型。該模型藉助T5的文本到文本框架以及自定義分隔符標記,通過給定的問題、上下文和正確答案,為多項選擇題生成三個合理的干擾項。
🚀 快速開始
你可以使用Hugging Face的Transformers管道按以下方式使用此模型:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-large-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question}{SEP_TOKEN}{correct}{SEP_TOKEN}{context}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
✨ 主要特性
- 基於T5的文本到文本框架,能夠根據給定的問題、上下文和正確答案,為多項選擇題生成三個合理的干擾項。
- 採用自定義分隔符標記,有效處理輸入和目標序列。
📦 安裝指南
暫未提及安裝相關內容,可參考Hugging Face的Transformers庫安裝說明進行安裝。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-large-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question}{SEP_TOKEN}{correct}{SEP_TOKEN}{context}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
📚 詳細文檔
模型概述
此實現基於PyTorch Lightning構建,對預訓練的T5-base模型進行微調,以生成干擾項選項。模型接受一個經過格式化的單一輸入序列,該序列包含問題、上下文和正確答案,並通過自定義標記分隔,然後生成包含三個干擾項的目標序列。這種方法在多項選擇題生成任務中尤為有用。
數據處理
輸入構建
每個輸入樣本是一個具有以下格式的單一字符串:
question {SEP_TOKEN} correct {SEP_TOKEN} context
- question:問題文本。
- context:上下文段落。
- correct:正確答案。
- SEP_TOKEN:添加到分詞器中的特殊標記,用於分隔不同字段。
目標構建
每個目標樣本的構建方式如下:
incorrect1 {SEP_TOKEN} incorrect2 {SEP_TOKEN} incorrect3
這種格式允許模型一次性生成三個干擾項。
訓練詳情
屬性 |
詳情 |
框架 |
PyTorch Lightning |
基礎模型 |
T5-base |
優化器 |
採用線性調度的Adam優化器(使用預熱調度器) |
批量大小 |
32 |
訓練輪數 |
5 |
學習率 |
2e-5 |
分詞處理 |
輸入:最大長度為512個標記;目標:最大長度為64個標記 |
特殊標記 |
自定義的SEP_TOKEN 被添加到分詞器中,用於分隔輸入和目標序列的不同部分 |
評估指標
模型使用每個生成干擾項的BLEU分數進行評估。以下是在測試集上獲得的BLEU分數:
干擾項 |
BLEU-1 |
BLEU-2 |
BLEU-3 |
BLEU-4 |
干擾項1 |
32.29 |
23.85 |
19.86 |
17.53 |
干擾項2 |
26.70 |
17.76 |
14.01 |
11.77 |
干擾項3 |
23.63 |
14.89 |
11.29 |
9.41 |
這些分數表明,與參考干擾項相比,該模型能夠生成具有較高n元語法重疊的干擾項。 |
|
|
|
|
🔧 技術細節
- 藉助T5的文本到文本框架,結合自定義分隔符標記,實現高效的干擾項生成。
- 採用PyTorch Lightning框架進行模型訓練,利用Adam優化器和線性調度策略。
- 對輸入和目標序列進行合理的分詞處理,確保模型能夠有效學習和生成干擾項。
📄 許可證
本項目採用MIT許可證。