🚀 搜索查询重写生成模型
本模型专为搜索查询重写而设计,采用序列到序列架构生成改写后的查询。它借助强化学习框架,结合策略梯度算法进一步提升性能。通过奖励函数训练,能对关键词进行释义,使生成的查询更多样化。该模型可与基于BM25的稀疏检索等方法集成,提高搜索中的文档召回率。
🚀 快速开始
若要使用此模型,可通过采样并设置重复惩罚来生成多样化的样本。以下是示例代码:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
nsent = 4
with torch.no_grad():
for i in range(nsent):
output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Target: {target_sequence}')
✨ 主要特性
- 特定用途设计:专为搜索查询重写而打造,适用于多种搜索场景,如网页搜索、电商搜索等。
- 强化学习优化:运用强化学习框架和策略梯度算法,提升模型生成多样化且相关查询的能力。
- 可集成性:能与稀疏检索方法集成,提高搜索中的文档召回率。
📦 安装指南
文档未提及具体安装步骤,可参考模型仓库中的说明进行安装。
💻 使用示例
基础用法
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
nsent = 4
with torch.no_grad():
for i in range(nsent):
output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Target: {target_sequence}')
高级用法
文档未提及高级用法相关代码,可根据实际需求调整模型的生成参数,如max_length
、num_beams
、do_sample
、repetition_penalty
等,以获得不同的生成效果。
📚 详细文档
预期用例
- 搜索查询重写:适用于网页搜索、电商搜索等场景,对查询进行改写以提高搜索效果。
- 虚拟助手和聊天机器人:帮助生成更自然、多样化的回复。
- 信息检索:提升信息检索的准确性和召回率。
模型描述
训练过程
- 训练过程从用Google的 T5-base模型 初始化序列到序列模型开始。
- 首先,使用 MS-MARCO查询对数据集对模型进行有监督训练。
- 随后,使用强化学习(RL)框架对模型进行微调,以增强其生成多样化且相关查询的能力。
- 采用策略梯度方法对模型进行微调。对于给定的输入查询,从模型中采样一组轨迹(改写后的查询)并计算奖励。应用策略梯度算法更新模型。
- 启发式地计算奖励以增强模型的释义能力。不过,这些奖励可根据需要用其他特定领域或特定目标的奖励函数替代。
更多详细信息请参考 此处。
模型来源
- 仓库:https://github.com/PraveenSH/RL-Query-Reformulation
🔧 技术细节
该模型采用序列到序列架构,结合强化学习框架和策略梯度算法进行训练。通过奖励函数引导模型生成多样化的查询,提高搜索中的文档召回率。训练过程包括初始化模型、有监督训练和强化学习微调等步骤。
📄 许可证
本模型采用Apache-2.0许可证。
信息表格
属性 |
详情 |
模型类型 |
生成式模型,用于搜索查询重写 |
训练数据 |
MS-MARCO查询对数据集 |
许可证 |
Apache-2.0 |
仓库地址 |
https://github.com/PraveenSH/RL-Query-Reformulation |