🚀 doc2query/all-with_prefix-t5-base-v1
这是一个基于 T5 的 doc2query 模型(也被称为 docT5query)。该模型可用于解决文档搜索中的词汇差距问题,并生成特定领域的训练数据,助力训练强大的密集嵌入模型。
🚀 快速开始
功能用途
此模型可用于以下两个主要方面:
- 文档扩展:为段落生成 20 - 40 个查询,并将段落和生成的查询索引到标准的 BM25 索引(如 Elasticsearch、OpenSearch 或 Lucene)中。生成的查询有助于缩小词汇搜索的词汇差距,因为这些查询包含同义词。此外,它还会对单词重新加权,即使重要单词在段落中很少出现,也会赋予更高的权重。在我们的 BEIR 论文中,我们证明了 BM25 + docT5query 是一个强大的搜索引擎。在 BEIR 仓库 中,我们有一个如何将 docT5query 与 Pyserini 结合使用的示例。
- 特定领域训练数据生成:可用于生成训练数据以学习嵌入模型。在 SBERT.net 上,我们有一个示例,展示了如何使用该模型为给定的未标记文本集合生成(查询,文本)对。这些对可用于训练强大的密集嵌入模型。
代码示例
from transformers import T5Tokenizer, T5ForConditionalGeneration
model_name = 'doc2query/all-with_prefix-t5-base-v1'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
prefix = "answer2question"
text = "Python is an interpreted, high-level and general-purpose programming language. Python's design philosophy emphasizes code readability with its notable use of significant whitespace. Its language constructs and object-oriented approach aim to help programmers write clear, logical code for small and large-scale projects."
text = prefix+": "+text
input_ids = tokenizer.encode(text, max_length=384, truncation=True, return_tensors='pt')
outputs = model.generate(
input_ids=input_ids,
max_length=64,
do_sample=True,
top_p=0.95,
num_return_sequences=5)
print("Text:")
print(text)
print("\nGenerated Queries:")
for i in range(len(outputs)):
query = tokenizer.decode(outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
⚠️ 重要提示
model.generate()
是非确定性的,每次运行时会产生不同的查询结果。
📦 安装指南
本部分文档未提及安装步骤,若有需要,请参考相关依赖库(如 transformers
)的官方安装说明。
💻 使用示例
基础用法
上述代码示例展示了如何使用该模型生成查询。以下是代码的基本流程:
- 导入所需的
T5Tokenizer
和 T5ForConditionalGeneration
。
- 加载预训练的模型和分词器。
- 定义前缀和输入文本。
- 对输入文本进行编码。
- 使用模型生成查询。
- 解码并打印生成的查询。
高级用法
你可以根据不同的前缀生成不同类型的输出,具体前缀和输出类型如下:
前缀 |
输出类型 |
answer2question |
从答案生成问题 |
review2title |
从评论生成标题 |
abstract2title |
从摘要生成标题 |
text2query |
从文本生成查询 |
📚 详细文档
训练信息
此模型在 google/t5-v1_1-base 的基础上进行了 575k 个训练步骤的微调。训练脚本可在本仓库的 train_script.py
中找到。
输入文本被截断为 384 个词块,输出文本最多生成 64 个词块。
该模型在大量数据集上进行训练,确切的数据集名称和权重可在本仓库的 data_config.json
中找到。大多数数据集可在 https://huggingface.co/sentence-transformers 上获取。
数据集包括但不限于:
- 来自 Reddit 的(标题,正文)对。
- 来自 StackExchange 和 Yahoo Answers! 的(标题,正文)对和(标题,答案)对。
- 来自亚马逊评论的(标题,评论)对。
- 来自 MS MARCO、NQ 和 GooAQ 的(查询,段落)对。
- 来自 Quora 和 WikiAnswers 的(问题,重复问题)对。
- 来自 S2ORC 的(标题,摘要)对。
前缀说明
该模型在训练时 使用了前缀:你需要在文本开头添加特定的索引,以定义你希望接收的输出文本类型。根据前缀的不同,输出也会不同。
例如,上述关于 Python 的文本根据不同前缀会产生以下输出:
前缀 |
输出 |
answer2question |
为什么我应该在我的业务中使用 Python?;Python 和 .NET 有什么区别?;Python 的设计理念是什么? |
review2title |
Python:一种强大且有用的语言;一种新的、改进的编程语言;面向对象、实用且易访问 |
abstract2title |
Python:软件开发平台;Python X 研究指南:编程的概念方法;Python:语言与方法 |
text2query |
Python 是一种低级语言吗?;Python 的主要思想是什么?;Python 是一种编程语言吗? |
所有可用的前缀如下:
- text2reddit
- question2title
- answer2question
- abstract2title
- review2title
- news2title
- text2query
- question2question
不同前缀的数据集和权重可在本仓库的 data_config.json
中找到。
📄 许可证
本项目采用 Apache-2.0 许可证。