🚀 doc2query/msmarco-portuguese-mt5-base-v1
這是一個基於mT5的doc2query模型(也被稱為docT5query)。該模型可用於解決文檔擴展和特定領域訓練數據生成的問題,為信息檢索和模型訓練提供了強大的支持。
🚀 快速開始
此模型可用於以下兩個主要場景:
- 文檔擴展:為段落生成20 - 40個查詢,並將段落和生成的查詢索引到標準的BM25索引(如Elasticsearch、OpenSearch或Lucene)中。生成的查詢有助於縮小詞彙搜索的詞彙差距,因為生成的查詢包含同義詞。此外,它會重新加權單詞,即使重要單詞在段落中很少出現,也會給予更高的權重。在我們的BEIR論文中,我們證明了BM25 + docT5query是一個強大的搜索引擎。在BEIR倉庫中,我們有一個如何將docT5query與Pyserini結合使用的示例。
- 特定領域訓練數據生成:可用於生成訓練數據以學習嵌入模型。在我們的GPL論文 / SBERT.net上的GPL示例中,我們有一個如何使用該模型為給定的未標記文本集合生成(查詢,文本)對的示例。這些對可用於訓練強大的密集嵌入模型。
📦 安裝指南
文檔未提及安裝步驟,故跳過該章節。
💻 使用示例
基礎用法
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
model_name = 'doc2query/msmarco-portuguese-mt5-base-v1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
text = "Python é uma linguagem de programação de alto nível, interpretada de script, imperativa, orientada a objetos, funcional, de tipagem dinâmica e forte. Foi lançada por Guido van Rossum em 1991. Atualmente, possui um modelo de desenvolvimento comunitário, aberto e gerenciado pela organização sem fins lucrativos Python Software Foundation. Apesar de várias partes da linguagem possuírem padrões e especificações formais, a linguagem, como um todo, não é formalmente especificada. O padrão de facto é a implementação CPython."
def create_queries(para):
input_ids = tokenizer.encode(para, return_tensors='pt')
with torch.no_grad():
sampling_outputs = model.generate(
input_ids=input_ids,
max_length=64,
do_sample=True,
top_p=0.95,
top_k=10,
num_return_sequences=5
)
beam_outputs = model.generate(
input_ids=input_ids,
max_length=64,
num_beams=5,
no_repeat_ngram_size=2,
num_return_sequences=5,
early_stopping=True
)
print("Paragraph:")
print(para)
print("\nBeam Outputs:")
for i in range(len(beam_outputs)):
query = tokenizer.decode(beam_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
print("\nSampling Outputs:")
for i in range(len(sampling_outputs)):
query = tokenizer.decode(sampling_outputs[i], skip_special_tokens=True)
print(f'{i + 1}: {query}')
create_queries(text)
注意:model.generate()
在使用top_k/top_n採樣時是不確定的,每次運行時會產生不同的查詢。
🔧 技術細節
該模型對google/mt5-base進行了66k個訓練步驟的微調(對來自MS MARCO的500k個訓練對進行了4個epoch的訓練)。訓練腳本請參考此倉庫中的train_script.py
。
輸入文本被截斷為320個詞塊,輸出文本最多生成64個詞塊。該模型在來自mMARCO數據集的(查詢,段落)對上進行訓練。
📄 許可證
本項目採用Apache-2.0許可證。