🚀 doc2query/msmarco-t5-base-v1
这是一个基于T5的doc2query模型(也称为docT5query)。该模型可用于解决文档搜索中的词汇差距问题,以及生成特定领域的训练数据,助力训练强大的密集嵌入模型。
🚀 快速开始
本模型可用于以下两个主要场景:
- 文档扩展:为段落生成20 - 40个查询,并将段落和生成的查询索引到标准的BM25索引(如Elasticsearch、OpenSearch或Lucene)中。生成的查询有助于缩小词汇搜索的词汇差距,因为生成的查询包含同义词。此外,它会重新加权单词,即使重要单词在段落中很少出现,也会赋予更高的权重。在我们的BEIR论文中,我们证明了BM25 + docT5query是一个强大的搜索引擎。在BEIR仓库中,我们有一个如何使用docT5query与Pyserini的示例。
- 特定领域训练数据生成:可用于生成训练数据以学习嵌入模型。在SBERT.net上,我们有一个如何使用该模型为给定的未标记文本集合生成(查询,文本)对的示例。这些对可用于训练强大的密集嵌入模型。
💻 使用示例
基础用法
from transformers import T5Tokenizer, T5ForConditionalGeneration
model_name = 'doc2query/msmarco-t5-base-v1'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
input_ids = tokenizer.encode(text, max_length=320, truncation=True, return_tensors='pt')
outputs = model.generate(
input_ids=input_ids,
max_length=64,
do_sample=True,
top_p=0.95,
num_return_sequences=5)
print("Text:")
print(text)
print("\nGenerated Queries:")
for i in range(len(outputs)):
query = tokenizer.decode(outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
⚠️ 重要提示
model.generate()
是非确定性的,每次运行时会产生不同的查询。
🔧 技术细节
本模型在 google/t5-v1_1-base 的基础上进行了微调,训练了31000个步骤(在来自MS MARCO的500000个训练对数据上约4个轮次)。训练脚本可在本仓库的 train_script.py
中查看。
输入文本被截断为320个词块,输出文本最多生成64个词块。该模型使用了来自 MS MARCO Passage-Ranking数据集 的(查询,段落)对进行训练。
📄 许可证
本项目采用Apache 2.0许可证。