🚀 multi-qa_v1-distilbert-mean_cos
SentenceTransformers 是一套模型和框架,可根据给定数据训练并生成句子嵌入向量。生成的句子嵌入向量可用于聚类、语义搜索等任务。本模型使用预训练的 distilbert-base-uncased 模型,并通过孪生网络设置和对比学习目标进行训练。我们使用 StackExchange 的问答对作为训练数据,使模型在问答嵌入相似度方面表现更稳健。对于此模型,我们使用隐藏状态的均值池化作为句子嵌入。
本模型由 Hugging Face 组织的 JAX/Flax 用于 NLP 和 CV 的社区周 期间开发。该模型是 使用 10 亿训练对训练有史以来最好的句子嵌入模型 项目的一部分。我们借助高效的硬件基础设施(7 个 TPU v3 - 8)以及谷歌 Flax、JAX 和云团队成员在高效深度学习框架方面的帮助来运行该项目。
🚀 快速开始
本模型旨在用作搜索引擎的句子编码器。给定输入句子,它将输出一个捕获句子语义信息的向量。该句子向量可用于语义搜索、聚类或句子相似度任务。
✨ 主要特性
- 作为句子编码器,能输出捕获句子语义信息的向量。
- 适用于语义搜索、聚类或句子相似度等任务。
📦 安装指南
此部分原文档未提及具体安装步骤,跳过。
💻 使用示例
基础用法
以下是如何使用 SentenceTransformers 库来获取给定文本特征的示例:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('flax-sentence-embeddings/multi-qa_v1-distilbert-mean_cos')
text = "Replace me by any question / answer you'd like."
text_embbedding = model.encode(text)
📚 详细文档
预期用途
我们的模型旨在用作搜索引擎的句子编码器。给定输入句子,它输出一个捕获句子语义信息的向量。该句子向量可用于语义搜索、聚类或句子相似度任务。
训练过程
预训练
我们使用预训练的 distilbert-base-uncased 模型。有关预训练过程的更多详细信息,请参考该模型的卡片。
微调
我们使用对比目标对模型进行微调。具体来说,我们计算批次中每个可能句子对的余弦相似度,然后通过与真实对进行比较来应用交叉熵损失。
超参数
我们在 TPU v3 - 8 上训练模型。我们使用 1024 的批次大小(每个 TPU 核心 128)进行 80k 步的训练。我们使用 500 的学习率预热。序列长度限制为 128 个标记。我们使用 AdamW 优化器,学习率为 2e - 5。完整的训练脚本可在当前存储库中找到。
训练数据
我们使用多个 Stackexchange 问答数据集的串联来微调我们的模型。还使用了 MSMARCO、NQ 等问答数据集。
数据集 |
论文 |
训练元组数量 |
Stack Exchange QA - Title & Answer |
- |
4,750,619 |
Stack Exchange |
- |
364,001 |
TriviaqQA |
- |
73,346 |
SQuAD2.0 |
paper |
87,599 |
Quora Question Pairs |
- |
103,663 |
Eli5 |
paper |
325,475 |
PAQ |
paper |
64,371,441 |
WikiAnswers |
paper |
77,427,422 |
MS MARCO |
paper |
9,144,553 |
GOOAQ: Open Question Answering with Diverse Answer Types |
paper |
3,012,496 |
Yahoo Answers Question/Answer |
paper |
681,164 |
SearchQA |
- |
582,261 |
Natural Questions (NQ) |
paper |
100,231 |
🔧 技术细节
本模型使用预训练的 distilbert-base-uncased 模型,通过孪生网络设置和对比学习目标进行训练。使用 StackExchange 的问答对作为训练数据,以提高模型在问答嵌入相似度方面的性能。在微调过程中,计算批次中句子对的余弦相似度并应用交叉熵损失。使用 AdamW 优化器和特定的超参数(如学习率、批次大小等)在 TPU v3 - 8 上进行训练。
📄 许可证
此部分原文档未提及许可证信息,跳过。