🚀 doc2query/msmarco-portuguese-mt5-base-v1
这是一个基于mT5的doc2query模型(也被称为docT5query)。该模型可用于解决文档扩展和特定领域训练数据生成的问题,为信息检索和模型训练提供了强大的支持。
🚀 快速开始
此模型可用于以下两个主要场景:
- 文档扩展:为段落生成20 - 40个查询,并将段落和生成的查询索引到标准的BM25索引(如Elasticsearch、OpenSearch或Lucene)中。生成的查询有助于缩小词汇搜索的词汇差距,因为生成的查询包含同义词。此外,它会重新加权单词,即使重要单词在段落中很少出现,也会给予更高的权重。在我们的BEIR论文中,我们证明了BM25 + docT5query是一个强大的搜索引擎。在BEIR仓库中,我们有一个如何将docT5query与Pyserini结合使用的示例。
- 特定领域训练数据生成:可用于生成训练数据以学习嵌入模型。在我们的GPL论文 / SBERT.net上的GPL示例中,我们有一个如何使用该模型为给定的未标记文本集合生成(查询,文本)对的示例。这些对可用于训练强大的密集嵌入模型。
📦 安装指南
文档未提及安装步骤,故跳过该章节。
💻 使用示例
基础用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_name = 'doc2query/msmarco-portuguese-mt5-base-v1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
text = "Python é uma linguagem de programação de alto nível, interpretada de script, imperativa, orientada a objetos, funcional, de tipagem dinâmica e forte. Foi lançada por Guido van Rossum em 1991. Atualmente, possui um modelo de desenvolvimento comunitário, aberto e gerenciado pela organização sem fins lucrativos Python Software Foundation. Apesar de várias partes da linguagem possuírem padrões e especificações formais, a linguagem, como um todo, não é formalmente especificada. O padrão de facto é a implementação CPython."
def create_queries(para):
input_ids = tokenizer.encode(para, return_tensors='pt')
with torch.no_grad():
sampling_outputs = model.generate(
input_ids=input_ids,
max_length=64,
do_sample=True,
top_p=0.95,
top_k=10,
num_return_sequences=5
)
beam_outputs = model.generate(
input_ids=input_ids,
max_length=64,
num_beams=5,
no_repeat_ngram_size=2,
num_return_sequences=5,
early_stopping=True
)
print("Paragraph:")
print(para)
print("\nBeam Outputs:")
for i in range(len(beam_outputs)):
query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
print("\nSampling Outputs:")
for i in range(len(sampling_outputs)):
query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
create_queries(text)
注意:model.generate()
在使用top_k/top_n采样时是不确定的,每次运行时会产生不同的查询。
🔧 技术细节
该模型对google/mt5-base进行了66k个训练步骤的微调(对来自MS MARCO的500k个训练对进行了4个epoch的训练)。训练脚本请参考此仓库中的train_script.py
。
输入文本被截断为320个词块,输出文本最多生成64个词块。该模型在来自mMARCO数据集的(查询,段落)对上进行训练。
📄 许可证
本项目采用Apache-2.0许可证。