模型简介
模型特点
模型能力
使用案例
🚀 multi-qa-mpnet-base-dot-v1
这是一个句子转换器模型,它能将句子和段落映射到768维的密集向量空间,专为语义搜索而设计。该模型在来自不同来源的2.15亿个(问题,答案)对上进行了训练。若想了解语义搜索的相关介绍,请查看:SBERT.net - 语义搜索。
🚀 快速开始
本模型可用于语义搜索,它能将查询语句和文本段落编码到一个密集向量空间中,从而为给定的段落找到相关文档。
✨ 主要特性
- 能够将句子和段落映射到768维的密集向量空间。
- 专为语义搜索设计,在2.15亿个(问题,答案)对上进行训练。
📦 安装指南
若你已安装句子转换器,使用此模型会非常简单:
pip install -U sentence-transformers
💻 使用示例
基础用法
from sentence_transformers import SentenceTransformer, util
query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]
#Load the model
model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')
#Encode query and documents
query_emb = model.encode(query)
doc_emb = model.encode(docs)
#Compute dot score between query and all document embeddings
scores = util.dot_score(query_emb, doc_emb)[0].cpu().tolist()
#Combine docs & scores
doc_score_pairs = list(zip(docs, scores))
#Sort by decreasing score
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
#Output passages & scores
for doc, score in doc_score_pairs:
print(score, doc)
高级用法
若未安装句子转换器,可按以下方式使用该模型:首先,将输入数据传入转换器模型,然后对上下文词嵌入应用正确的池化操作。
from transformers import AutoTokenizer, AutoModel
import torch
#CLS Pooling - Take output from first token
def cls_pooling(model_output):
return model_output.last_hidden_state[:,0]
#Encode text
def encode(texts):
# Tokenize sentences
encoded_input = tokenizer(texts, padding=True, truncation=True, return_tensors='pt')
# Compute token embeddings
with torch.no_grad():
model_output = model(**encoded_input, return_dict=True)
# Perform pooling
embeddings = cls_pooling(model_output)
return embeddings
# Sentences we want sentence embeddings for
query = "How many people live in London?"
docs = ["Around 9 Million people live in London", "London is known for its financial district"]
# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
model = AutoModel.from_pretrained("sentence-transformers/multi-qa-mpnet-base-dot-v1")
#Encode query and docs
query_emb = encode(query)
doc_emb = encode(docs)
#Compute dot score between query and all document embeddings
scores = torch.mm(query_emb, doc_emb.transpose(0, 1))[0].cpu().tolist()
#Combine docs & scores
doc_score_pairs = list(zip(docs, scores))
#Sort by decreasing score
doc_score_pairs = sorted(doc_score_pairs, key=lambda x: x[1], reverse=True)
#Output passages & scores
for doc, score in doc_score_pairs:
print(score, doc)
🔧 技术细节
以下是该模型使用的一些技术细节:
属性 | 详情 |
---|---|
维度 | 768 |
是否生成归一化嵌入 | 否 |
池化方法 | CLS池化 |
合适的得分函数 | 点积(例如 util.dot_score ) |
📚 详细文档
背景
该项目旨在使用自监督对比学习目标,在非常大的句子级数据集上训练句子嵌入模型。采用对比学习目标:给定一对句子中的一个,模型应从一组随机采样的其他句子中预测出在数据集中实际与之配对的句子。
此模型是在由Hugging Face组织的使用JAX/Flax进行自然语言处理和计算机视觉的社区周期间开发的。该模型是使用10亿个训练对训练有史以来最好的句子嵌入模型项目的一部分。项目运行得益于高效的硬件基础设施:7个TPU v3 - 8,以及谷歌Flax、JAX和云团队成员在高效深度学习框架方面的指导。
预期用途
该模型旨在用于语义搜索:它将查询/问题和文本段落编码到一个密集向量空间中,为给定的段落找到相关文档。
请注意,词块数量限制为512:超过此长度的文本将被截断。此外,该模型仅在最多250个词块的输入文本上进行了训练,对于较长的文本可能效果不佳。
训练过程
完整的训练脚本可在当前仓库中找到:train_script.py
。
预训练
使用预训练的mpnet-base
模型。有关预训练过程的更详细信息,请参考该模型的卡片。
训练
使用多个数据集的拼接来微调模型。总共约有2.15亿个(问题,答案)对。
每个数据集按照加权概率进行采样,具体配置详见data_config.json
文件。
该模型使用MultipleNegativesRankingLoss进行训练,采用CLS池化、点积作为相似性函数,缩放比例为1。
数据集 | 训练元组数量 |
---|---|
WikiAnswers 来自WikiAnswers的重复问题对 | 77,427,422 |
PAQ 为维基百科中的每个段落自动生成的(问题,段落)对 | 64,371,441 |
Stack Exchange 来自所有StackExchanges的(标题,正文)对 | 25,316,456 |
Stack Exchange 来自所有StackExchanges的(标题,答案)对 | 21,396,559 |
MS MARCO 来自必应搜索引擎的50万个查询的三元组(查询,答案,硬负样本) | 17,579,773 |
GOOAQ: Open Question Answering with Diverse Answer Types 300万个谷歌查询和谷歌特色片段的(查询,答案)对 | 3,012,496 |
Amazon-QA 来自亚马逊产品页面的(问题,答案)对 | 2,448,839 |
Yahoo Answers 来自雅虎问答的(标题,答案)对 | 1,198,260 |
Yahoo Answers 来自雅虎问答的(问题,答案)对 | 681,164 |
Yahoo Answers 来自雅虎问答的(标题,问题)对 | 659,896 |
SearchQA 14万个问题的(问题,答案)对,每个问题有前5个谷歌片段 | 582,261 |
ELI5 来自Reddit ELI5(像解释给五岁小孩一样解释)的(问题,答案)对 | 325,475 |
Stack Exchange 重复问题对(标题) | 304,525 |
Quora Question Triplets 来自Quora问题对数据集的三元组(问题,重复问题,硬负样本) | 103,663 |
Natural Questions (NQ) 10万个真实谷歌查询与相关维基百科段落的(问题,段落)对 | 100,231 |
SQuAD2.0 来自SQuAD2.0数据集的(问题,段落)对 | 87,599 |
TriviaQA (问题,证据)对 | 73,346 |
总计 | 214,988,242 |
⚠️ 重要提示
词块数量限制为512,超过此长度的文本将被截断。该模型仅在最多250个词块的输入文本上进行了训练,对于较长的文本可能效果不佳。
💡 使用建议
在使用模型进行语义搜索时,尽量确保输入文本的词块数量在250以内,以获得更好的效果。







