🚀 句子相似度模型
本项目旨在使用自监督对比学习目标,在非常大的句子级数据集上训练句子嵌入模型。通过微调预训练模型,我们得到了能够有效捕捉句子语义信息的模型,可用于信息检索、聚类或句子相似度任务。
🚀 快速开始
本模型可作为句子编码器使用。给定输入句子,它会输出一个捕捉句子语义信息的向量。句子向量可用于信息检索、聚类或句子相似度任务。
✨ 主要特性
- 强大的语义捕捉能力:输出的向量能够有效捕捉句子的语义信息。
- 多场景适用:可用于信息检索、聚类或句子相似度等多种任务。
📦 安装指南
使用此模型需要安装 SentenceTransformers 库。
💻 使用示例
基础用法
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('flax-sentence-embeddings/reddit_single-context_mpnet-base')
text = "Replace me by any text you'd like."
text_embbedding = model.encode(text)
📚 详细文档
模型描述
本项目旨在使用自监督对比学习目标,在非常大的句子级数据集上训练句子嵌入模型。我们使用了预训练的 'mpnet-base' 模型,并在 7 亿个句子对的数据集上进行了微调。我们采用对比学习目标:给定一对句子中的一个句子,模型应从一组随机采样的其他句子中预测出在数据集中实际与之配对的句子。
我们在由 Hugging Face 组织的 使用 JAX/Flax 进行自然语言处理和计算机视觉的社区周 期间开发了此模型。我们将此模型作为 使用 10 亿个训练对训练有史以来最好的句子嵌入模型 项目的一部分进行开发。我们受益于高效的硬件基础设施来运行该项目:7 个 TPU v3 - 8,以及谷歌的 Flax、JAX 和云团队成员在高效深度学习框架方面的指导。
预期用途
我们的模型旨在用作句子编码器。给定输入句子,它会输出一个捕捉句子语义信息的向量。句子向量可用于信息检索、聚类或句子相似度任务。
训练过程
预训练
我们使用预训练的 'mpnet-base' 模型。有关预训练过程的更详细信息,请参考模型卡片。
微调
我们使用对比目标对模型进行微调。形式上,我们从批次中的每个可能的句子对计算余弦相似度。然后通过与真实对进行比较来应用交叉熵损失。
超参数
我们在 TPU v3 - 8 上训练模型。我们使用 1024 的批次大小(每个 TPU 核心 128)进行了 540k 步的训练。我们使用了 500 的学习率预热。序列长度限制为 128 个标记。我们使用了 AdamW 优化器,学习率为 2e - 5。完整的训练脚本可在此存储库中找到。
训练数据
我们使用多个数据集的拼接来微调我们的模型。句子对的总数超过 7 亿个句子。我们根据加权概率对每个数据集进行采样,配置详情见 data_config.json
文件。我们在构建数据集时仅使用第一个上下文响应。
属性 |
详情 |
模型类型 |
基于预训练的'mpnet - base'模型微调的句子嵌入模型 |
训练数据 |
多个数据集拼接,总数超过 7 亿个句子对,具体数据集如 [Reddit conversationnal](https://github.com/PolyAI - LDN/conversational - datasets/tree/master/reddit) ,有 726,484,430 个训练元组 |