🚀 基于T5-base的干扰项生成模型
本仓库包含一个经过微调的 T5-base 模型,用于生成选择题的干扰项。该模型利用T5的文本到文本框架和自定义分隔符,根据给定的问题、上下文和正确答案,生成三个合理的干扰项。
🚀 快速开始
你可以使用Hugging Face的Transformers管道来使用这个模型,示例代码如下:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-base-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
✨ 主要特性
- 利用T5的文本到文本框架和自定义分隔符,可根据给定的问题、上下文和正确答案,生成三个合理的干扰项。
- 采用单输入序列格式,包含问题、上下文和正确答案,通过自定义分隔符分隔,便于模型处理。
- 能够在一次推理中生成三个干扰项,提高了生成效率。
📦 安装指南
暂未提及安装相关内容,可参考Hugging Face的Transformers库的安装方式。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
model_name = "fares7elsadek/t5-base-distractor-generation"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
SEP_TOKEN = "<sep>"
def generate_distractors(question, context, correct, max_length=64):
input_text = f"{question} {SEP_TOKEN} {context} {SEP_TOKEN} {correct}"
inputs = tokenizer([input_text], return_tensors="pt", truncation=True, padding=True)
outputs = model.generate(
input_ids=inputs["input_ids"],
attention_mask=inputs["attention_mask"],
max_length=max_length
)
decoded = tokenizer.decode(outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
distractors = [d.strip() for d in decoded.split(SEP_TOKEN)]
return distractors
question = "What is the capital of France?"
context = "France is a country in Western Europe known for its rich history and cultural heritage."
correct = "Paris"
print(generate_distractors(question, context, correct))
📚 详细文档
模型概述
该模型基于 PyTorch Lightning 构建,对预训练的 T5-base 模型进行微调,以生成干扰项。模型接受一个包含问题、上下文和正确答案的单输入序列(通过自定义分隔符分隔),并生成一个包含三个干扰项的目标序列。这种方法在选择题生成任务中特别有用。
数据处理
输入构建
每个输入样本是一个字符串,格式如下:
question {SEP_TOKEN} context {SEP_TOKEN} correct
- question:问题文本。
- context:上下文段落。
- correct:正确答案。
- SEP_TOKEN:添加到分词器中的特殊标记,用于分隔不同字段。
目标构建
每个目标样本的构建方式如下:
incorrect1 {SEP_TOKEN} incorrect2 {SEP_TOKEN} incorrect3
这种格式允许模型一次性生成三个干扰项。
训练细节
- 框架:PyTorch Lightning
- 基础模型:T5-base
- 优化器:使用线性调度的Adam优化器(带有预热调度器)
- 批量大小:32
- 训练轮数:5
- 学习率:2e-5
- 分词处理:
- 输入:最大长度为512个标记
- 目标:最大长度为64个标记
- 特殊标记:自定义的
SEP_TOKEN
被添加到分词器中,用于分隔输入和目标序列的不同部分。
评估指标
模型使用BLEU分数对每个生成的干扰项进行评估。以下是在测试集上获得的BLEU分数:
干扰项 |
BLEU-1 |
BLEU-2 |
BLEU-3 |
BLEU-4 |
干扰项1 |
29.59 |
21.55 |
17.86 |
15.75 |
干扰项2 |
25.21 |
16.81 |
13.00 |
10.78 |
干扰项3 |
23.99 |
15.78 |
12.35 |
10.52 |
这些分数表明,与参考干扰项相比,该模型能够生成具有较高n-gram重叠率的干扰项。 |
|
|
|
|
🔧 技术细节
- 模型基于T5-base架构,通过微调适应干扰项生成任务。
- 使用自定义分隔符
SEP_TOKEN
来分隔输入和目标序列的不同部分,便于模型理解和处理。
- 在训练过程中,使用了线性调度的Adam优化器和预热调度器,有助于模型的收敛和性能提升。
📄 许可证
本项目采用MIT许可证。