🚀 基於T5-base的干擾項生成模型
本倉庫包含一個經過微調的 T5-base 模型,用於生成選擇題的干擾項。該模型利用T5的文本到文本框架和自定義分隔符,根據給定的問題、上下文和正確答案,生成三個合理的干擾項。
🚀 快速開始
你可以使用Hugging Face的Transformers管道來使用這個模型,示例代碼如下:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-base-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
✨ 主要特性
- 利用T5的文本到文本框架和自定義分隔符,可根據給定的問題、上下文和正確答案,生成三個合理的干擾項。
- 採用單輸入序列格式,包含問題、上下文和正確答案,通過自定義分隔符分隔,便於模型處理。
- 能夠在一次推理中生成三個干擾項,提高了生成效率。
📦 安裝指南
暫未提及安裝相關內容,可參考Hugging Face的Transformers庫的安裝方式。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-base-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
📚 詳細文檔
模型概述
該模型基於 PyTorch Lightning 構建,對預訓練的 T5-base 模型進行微調,以生成干擾項。模型接受一個包含問題、上下文和正確答案的單輸入序列(通過自定義分隔符分隔),並生成一個包含三個干擾項的目標序列。這種方法在選擇題生成任務中特別有用。
數據處理
輸入構建
每個輸入樣本是一個字符串,格式如下:
question {SEP_TOKEN} context {SEP_TOKEN} correct
- question:問題文本。
- context:上下文段落。
- correct:正確答案。
- SEP_TOKEN:添加到分詞器中的特殊標記,用於分隔不同字段。
目標構建
每個目標樣本的構建方式如下:
incorrect1 {SEP_TOKEN} incorrect2 {SEP_TOKEN} incorrect3
這種格式允許模型一次性生成三個干擾項。
訓練細節
- 框架:PyTorch Lightning
- 基礎模型:T5-base
- 優化器:使用線性調度的Adam優化器(帶有預熱調度器)
- 批量大小:32
- 訓練輪數:5
- 學習率:2e-5
- 分詞處理:
- 輸入:最大長度為512個標記
- 目標:最大長度為64個標記
- 特殊標記:自定義的
SEP_TOKEN
被添加到分詞器中,用於分隔輸入和目標序列的不同部分。
評估指標
模型使用BLEU分數對每個生成的干擾項進行評估。以下是在測試集上獲得的BLEU分數:
干擾項 |
BLEU-1 |
BLEU-2 |
BLEU-3 |
BLEU-4 |
干擾項1 |
29.59 |
21.55 |
17.86 |
15.75 |
干擾項2 |
25.21 |
16.81 |
13.00 |
10.78 |
干擾項3 |
23.99 |
15.78 |
12.35 |
10.52 |
這些分數表明,與參考干擾項相比,該模型能夠生成具有較高n-gram重疊率的干擾項。 |
|
|
|
|
🔧 技術細節
- 模型基於T5-base架構,通過微調適應干擾項生成任務。
- 使用自定義分隔符
SEP_TOKEN
來分隔輸入和目標序列的不同部分,便於模型理解和處理。
- 在訓練過程中,使用了線性調度的Adam優化器和預熱調度器,有助於模型的收斂和性能提升。
📄 許可證
本項目採用MIT許可證。