🚀 搜索查詢重寫生成模型
本模型專為搜索查詢重寫而設計,採用序列到序列架構生成改寫後的查詢。它藉助強化學習框架,結合策略梯度算法進一步提升性能。通過獎勵函數訓練,能對關鍵詞進行釋義,使生成的查詢更多樣化。該模型可與基於BM25的稀疏檢索等方法集成,提高搜索中的文檔召回率。
🚀 快速開始
若要使用此模型,可通過採樣並設置重複懲罰來生成多樣化的樣本。以下是示例代碼:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
nsent = 4
with torch.no_grad():
for i in range(nsent):
output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Target: {target_sequence}')
✨ 主要特性
- 特定用途設計:專為搜索查詢重寫而打造,適用於多種搜索場景,如網頁搜索、電商搜索等。
- 強化學習優化:運用強化學習框架和策略梯度算法,提升模型生成多樣化且相關查詢的能力。
- 可集成性:能與稀疏檢索方法集成,提高搜索中的文檔召回率。
📦 安裝指南
文檔未提及具體安裝步驟,可參考模型倉庫中的說明進行安裝。
💻 使用示例
基礎用法
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
nsent = 4
with torch.no_grad():
for i in range(nsent):
output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
print(f'Target: {target_sequence}')
高級用法
文檔未提及高級用法相關代碼,可根據實際需求調整模型的生成參數,如max_length
、num_beams
、do_sample
、repetition_penalty
等,以獲得不同的生成效果。
📚 詳細文檔
預期用例
- 搜索查詢重寫:適用於網頁搜索、電商搜索等場景,對查詢進行改寫以提高搜索效果。
- 虛擬助手和聊天機器人:幫助生成更自然、多樣化的回覆。
- 信息檢索:提升信息檢索的準確性和召回率。
模型描述
訓練過程
- 訓練過程從用Google的 T5-base模型 初始化序列到序列模型開始。
- 首先,使用 MS-MARCO查詢對數據集對模型進行有監督訓練。
- 隨後,使用強化學習(RL)框架對模型進行微調,以增強其生成多樣化且相關查詢的能力。
- 採用策略梯度方法對模型進行微調。對於給定的輸入查詢,從模型中採樣一組軌跡(改寫後的查詢)並計算獎勵。應用策略梯度算法更新模型。
- 啟發式地計算獎勵以增強模型的釋義能力。不過,這些獎勵可根據需要用其他特定領域或特定目標的獎勵函數替代。
更多詳細信息請參考 此處。
模型來源
- 倉庫:https://github.com/PraveenSH/RL-Query-Reformulation
🔧 技術細節
該模型採用序列到序列架構,結合強化學習框架和策略梯度算法進行訓練。通過獎勵函數引導模型生成多樣化的查詢,提高搜索中的文檔召回率。訓練過程包括初始化模型、有監督訓練和強化學習微調等步驟。
📄 許可證
本模型採用Apache-2.0許可證。
信息表格
屬性 |
詳情 |
模型類型 |
生成式模型,用於搜索查詢重寫 |
訓練數據 |
MS-MARCO查詢對數據集 |
許可證 |
Apache-2.0 |
倉庫地址 |
https://github.com/PraveenSH/RL-Query-Reformulation |