RL-Query-Reformulation
A generative model for search query rewriting, leveraging a sequence-to-sequence architecture and reinforcement learning to enhance document recall in search.
Quick Start
This is a generative model tailored for search query rewriting. It uses a sequence-to-sequence architecture and a reinforcement learning framework to generate diverse and relevant reformulated queries. It can be integrated with sparse retrieval methods to improve document recall in search.
Features
- Sequence-to-Sequence Architecture: Generates reformulated queries using a sequence-to-sequence model.
- Reinforcement Learning: Employs a policy gradient algorithm to enhance the model's performance.
- Diverse Query Generation: Uses reward functions to diversify the generated queries by paraphrasing keywords.
- Integration with Sparse Retrieval: Can be integrated with sparse retrieval methods, such as BM25-based retrieval, to improve document recall (see the sketch after this list).
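To illustrate the sparse-retrieval integration, here is a minimal sketch that unions BM25 results for the original query and its rewrites. It assumes the third-party rank_bm25 package and a toy in-memory corpus, neither of which is part of this model; the rewrites would normally come from model.generate as shown in the Usage Examples below.

# Minimal sketch: expand BM25 retrieval with model-generated query rewrites.
# Assumes the rank_bm25 package (pip install rank_bm25) and a toy corpus.
from rank_bm25 import BM25Okapi

corpus = [
    "how to bake chewy chocolate chip cookies",
    "best oven temperature for baking cookies",
    "simple sugar cookie recipe for beginners",
]
bm25 = BM25Okapi([doc.split() for doc in corpus])

def top_docs(query, k=2):
    scores = bm25.get_scores(query.split())
    ranked = sorted(range(len(corpus)), key=lambda i: scores[i], reverse=True)
    return ranked[:k]

original_query = "how to bake great cookie"
# In practice these would be sampled from model.generate(...) as shown below.
rewrites = ["best cookie baking recipe", "tips for baking cookies"]

recalled = set(top_docs(original_query))   # documents for the original query
for q in rewrites:
    recalled.update(top_docs(q))           # union with documents for each rewrite
print(sorted(recalled))

Taking the union of result sets is the simplest way to use the rewrites; a production system might instead merge ranked lists or pass the expanded candidate set to a re-ranker.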
Installation
The model can be used directly through the transformers library. Make sure you have torch and transformers installed in your Python environment. You can install them using the following commands:
pip install torch
pip install transformers
Usage Examples
Basic Usage
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
nsent = 4
with torch.no_grad():
    for i in range(nsent):
        output = model.generate(input_ids, max_length=35, num_beams=1, do_sample=True, repetition_penalty=1.8)
        target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')
Advanced Usage
For more advanced usage, you can adjust the generation parameters according to your specific needs. For example, you can change the max_length, num_beams, do_sample, and repetition_penalty parameters to control the generation process.
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.eval()
input_sequence = "how to bake great cookie"
input_ids = tokenizer(input_sequence, return_tensors="pt").input_ids
print(f'Input: {input_sequence}')
max_length = 50
num_beams = 3
do_sample = True
repetition_penalty = 2.0
nsent = 4
with torch.no_grad():
    for i in range(nsent):
        output = model.generate(input_ids, max_length=max_length, num_beams=num_beams, do_sample=do_sample, repetition_penalty=repetition_penalty)
        target_sequence = tokenizer.decode(output[0], skip_special_tokens=True)
        print(f'Target: {target_sequence}')
Documentation
Intended use cases
- Query rewriting for search (web, e-commerce)
- Virtual assistants and chatbots
- Information retrieval
Model Description
Training Procedure
- Initialize the sequence-to-sequence model with Google's T5-base model.
- Conduct supervised training using the MS-MARCO query pairs dataset.
- Fine-tune the model using a reinforcement learning (RL) framework to enhance its ability to generate diverse and relevant queries.
- Apply a policy gradient approach to fine-tune the model. For a given input query, sample a set of trajectories (reformulated queries) from the model and compute the reward. Then, use the policy gradient algorithm to update the model.
- Compute rewards heuristically to enhance the model's paraphrasing capability. These rewards can be replaced with other domain-specific or goal-specific reward functions as needed; a minimal sketch of this RL step follows this list.
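The following is a minimal, illustrative REINFORCE-style sketch of this procedure, not the repository's training code; the compute_reward heuristic, the learning rate, and the single-query batch are assumptions made for the example.

# Illustrative REINFORCE-style update for query reformulation (not the repo's actual training code).
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

MODEL_ID = "prhegde/t5-query-reformulation-RL"
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

def compute_reward(source, candidate):
    # Hypothetical heuristic: reward rewrites that keep some overlap with the
    # source query but are not verbatim copies (encourages paraphrasing).
    src, cand = set(source.lower().split()), set(candidate.lower().split())
    overlap = len(src & cand) / max(len(src), 1)
    novelty = len(cand - src) / max(len(cand), 1)
    return overlap * novelty

queries = ["how to bake great cookie"]  # toy single-query batch
for query in queries:
    input_ids = tokenizer(query, return_tensors="pt").input_ids
    # Sample a set of trajectories (reformulated queries) from the current policy.
    samples = model.generate(input_ids, do_sample=True, num_beams=1,
                             num_return_sequences=4, max_length=35)
    losses = []
    for seq in samples:
        candidate = tokenizer.decode(seq, skip_special_tokens=True)
        reward = compute_reward(query, candidate)
        labels = seq[1:].unsqueeze(0).clone()            # drop the decoder-start token
        labels[labels == tokenizer.pad_token_id] = -100  # ignore padding in the loss
        out = model(input_ids=input_ids, labels=labels)
        losses.append(reward * out.loss)  # REINFORCE surrogate: reward-weighted NLL
    loss = torch.stack(losses).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In the actual model, compute_reward would be replaced by the heuristic paraphrasing reward (or a domain- or goal-specific reward) described above.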
Refer to the repository linked below for more details.
Model Sources
- Repository: https://github.com/PraveenSH/RL-Query-Reformulation
Technical Details
This model uses a sequence-to-sequence architecture and a reinforcement learning framework to generate reformulated queries. It employs a policy gradient algorithm to fine-tune the model and heuristically computes rewards to enhance the model's paraphrasing capability. The model is trained on the MS-MARCO query pairs dataset and can be integrated with sparse retrieval methods to improve document recall in search.
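As a hypothetical illustration of the recall benefit, the sketch below compares recall for the original query alone against the union of the original query and its rewrites; the corpus, relevance labels, trivial keyword-overlap retriever, and rewrites are all invented for the example and stand in for a real BM25 index and query log.

# Hypothetical recall comparison: original query alone vs. original query plus rewrites.
corpus = {
    0: "chewy chocolate chip cookie recipe",
    1: "oven temperature and baking times for cookies",
    2: "history of the chocolate chip cookie",
    3: "tips for soft and great tasting cookies",
}
relevant = {0, 1, 3}  # made-up relevance judgments for the query intent

def retrieve(query, k=2):
    # Trivial keyword-overlap retriever standing in for a BM25 index.
    q = set(query.lower().split())
    ranked = sorted(corpus, key=lambda i: len(q & set(corpus[i].split())), reverse=True)
    return set(ranked[:k])

def recall(hits):
    return len(hits & relevant) / len(relevant)

original = "how to bake great cookie"
rewrites = ["best cookie baking recipe", "oven temperature for baking cookies"]  # e.g. from model.generate

baseline = retrieve(original)
expanded = set(baseline)
for q in rewrites:
    expanded |= retrieve(q)

print(f"recall, original query only: {recall(baseline):.2f}")
print(f"recall, with rewrites:       {recall(expanded):.2f}")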
License
This model is licensed under the Apache-2.0 license.