🚀 E5-base
E5-base produces text embeddings learned through weakly-supervised contrastive pre-training and can be used for text retrieval, semantic similarity, and related tasks. The model has 12 layers and an embedding size of 768.
🚀 Quick Start
News (May 2023): we recommend switching to e5-base-v2, which offers better performance with the same usage.
This model was introduced in the paper Text Embeddings by Weakly-Supervised Contrastive Pre-training by Liang Wang, Nan Yang, et al., published on arXiv in 2022.
✨ Key Features
- Efficient embeddings: converts text into 768-dimensional embedding vectors.
- Weakly-supervised training: uses weakly-supervised contrastive pre-training to improve embedding quality.
- Multi-task support: supports retrieval, semantic similarity, and other natural language processing tasks.
📦 Installation
To use this model, install the transformers library:

pip install transformers

To use the sentence_transformers library instead, run:

pip install sentence_transformers~=2.2.2
💻 Usage Examples
Basic Usage
Below is an example of encoding queries and passages from the MS-MARCO passage ranking dataset:
import torch.nn.functional as F

from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    # Zero out padding positions, then average the remaining token embeddings.
    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]


# Each input text should start with "query: " or "passage: ".
input_texts = ['query: how much protein should a female eat',
               'query: summit define',
               "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
               "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."]

tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base')
model = AutoModel.from_pretrained('intfloat/e5-base')

# Tokenize the input texts (truncated to 512 tokens)
batch_dict = tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt')

outputs = model(**batch_dict)
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

# Normalize embeddings so that dot products equal cosine similarities
embeddings = F.normalize(embeddings, p=2, dim=1)
scores = (embeddings[:2] @ embeddings[2:].T) * 100
print(scores.tolist())
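The resulting scores matrix holds query-to-passage similarities (cosine similarity scaled by 100), with one row per query and one column per passage. A minimal, hypothetical follow-up that continues the example above and picks the best passage for each query:

# Continues the example above; higher score means more relevant.
best_passage_per_query = scores.argmax(dim=1)
print(best_passage_per_query.tolist())  # expected: [0, 1], each query matching its own passage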
Advanced Usage
Example using the sentence_transformers library:
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('intfloat/e5-base')

input_texts = [
    'query: how much protein should a female eat',
    'query: summit define',
    "passage: As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "passage: Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]

embeddings = model.encode(input_texts, normalize_embeddings=True)
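Because normalize_embeddings=True returns L2-normalized vectors, cosine similarity reduces to a dot product. A minimal sketch continuing the example above:

# embeddings is a NumPy array of shape (4, 768); rows 0-1 are queries, rows 2-3 are passages.
scores = embeddings[:2] @ embeddings[2:].T
print(scores)  # pairwise query-passage cosine similarities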
📚 Detailed Documentation
Input Text Prefix Rules
- For asymmetric tasks such as passage retrieval in open QA or ad-hoc information retrieval, use the "query: " and "passage: " prefixes respectively.
- For symmetric tasks such as semantic similarity and paraphrase retrieval, use the "query: " prefix for both texts.
- When using embeddings as features, e.g. for linear-probing classification or clustering, use the "query: " prefix (see the sketch after this list).
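A minimal illustration of these rules; the texts and variable names below are made up for the example:

# Asymmetric task (e.g. passage retrieval): queries and passages get different prefixes.
question = "how much protein should a female eat"
document = "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day."
retrieval_inputs = [f"query: {question}", f"passage: {document}"]

# Symmetric task (e.g. semantic similarity) or feature extraction: prefix both texts with "query: ".
sentence_a = "summit define"
sentence_b = "what does the word summit mean"
similarity_inputs = [f"query: {sentence_a}", f"query: {sentence_b}"]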
🔧 Technical Details
The model has 12 layers and an embedding size of 768. It learns text embeddings through weakly-supervised contrastive pre-training, using the InfoNCE contrastive loss with a temperature of 0.01.
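For reference, below is a minimal sketch of an InfoNCE objective with in-batch negatives at temperature 0.01; it only illustrates the shape of the loss and is not the actual training code from the paper:

import torch
import torch.nn.functional as F

def info_nce_loss(query_emb: torch.Tensor, passage_emb: torch.Tensor, temperature: float = 0.01) -> torch.Tensor:
    # The i-th passage is the positive for the i-th query; all other passages in the batch act as negatives.
    query_emb = F.normalize(query_emb, p=2, dim=1)
    passage_emb = F.normalize(passage_emb, p=2, dim=1)
    logits = (query_emb @ passage_emb.T) / temperature
    labels = torch.arange(logits.size(0), device=logits.device)
    return F.cross_entropy(logits, labels)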
📄 License
This model is released under the MIT License.
📋 Information Table

| Property | Details |
|----------|---------|
| Model type | Text embedding model based on weakly-supervised contrastive pre-training |
| Training data | Not specified |
Usage Notes
⚠️ Important Note
Input texts must be prefixed with "query: " or "passage: "; otherwise model performance will degrade.
💡 Usage Tips
If your reproduced results differ slightly from those reported in the model card, this is likely due to different transformers and pytorch versions. For text embedding tasks, the relative order of cosine similarity scores matters more than their absolute values.
Citation
If you find our paper or model helpful, please cite it as follows:
@article{wang2022text,
  title={Text Embeddings by Weakly-Supervised Contrastive Pre-training},
  author={Wang, Liang and Yang, Nan and Huang, Xiaolong and Jiao, Binxing and Yang, Linjun and Jiang, Daxin and Majumder, Rangan and Wei, Furu},
  journal={arXiv preprint arXiv:2212.03533},
  year={2022}
}
Limitations
This model only works for English text; long inputs are truncated to at most 512 tokens.
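To illustrate the truncation behaviour, a small, hypothetical check of how many tokens actually reach the model:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base')
long_text = "passage: " + "protein intake guidance " * 300  # illustrative over-length input
ids = tokenizer(long_text, max_length=512, truncation=True)["input_ids"]
print(len(ids))  # capped at 512; tokens beyond the limit are discarded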