Splade PP En V1
模型简介
模型特点
模型能力
使用案例
🚀 独立实现适用于工业场景的SPLADE++模型
本项目独立实现了SPLADE++模型(又名splade - cocondenser*及其系列),并针对工业场景进行了一些效率调整。该模型结合了两项强大研究的优势,旨在平衡检索效果(质量)和检索效率(延迟和成本),为工业应用提供更合适的解决方案。
🚀 快速开始
本工作借鉴了两项重要研究:Naver的《From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective》论文和Google的SparseEmbed。感谢这两个团队的杰出工作。
✨ 主要特性
1. 稀疏表示与语义搜索的结合
- 结合了词法搜索和语义搜索的优点,学习具有一定可解释性的查询和文档的稀疏表示。
- 稀疏表示可作为查询和文档的隐式或显式(潜在、上下文相关)扩展机制。
2. 工业场景的效率优化
- 进行了非常细微的检索效率调整,使其更适合工业环境。
- 实现了检索效果(质量)和检索效率(延迟和成本)的良好平衡。
3. 模型性能表现
- 在ID数据上达到了MRR@10为37.22的竞争力效果(OOD为48.7),且检索延迟为47.27ms(多线程)。
- 在消费级GPU上,每个查询仅使用5个负样本即可实现上述性能。
📦 安装指南
暂未提供明确的安装步骤,可参考后续使用部分结合自身环境进行配置。
💻 使用示例
基础用法
# 以下是使用SPLADERunner库进行文档扩展的示例
pip install spladerunner
#One-time init
from spladerunner import Expander
# Default model is the document expander.
exapander = Expander()
#Sample Document expansion
sparse_rep = expander.expand(
["The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."])
高级用法
# 使用HuggingFace库进行模型调用的完整代码示例
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
model.to(device)
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
inputs = tokenizer(sentence, return_tensors='pt')
inputs = {key: val.to(device) for key, val in inputs.items()}
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
outputs = model(**inputs)
logits, attention_mask = outputs.logits, attention_mask
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vector = max_val.squeeze()
cols = vector.nonzero().squeeze().cpu().tolist()
print("number of actual dimensions: ", len(cols))
weights = vector[cols].cpu().tolist()
d = {k: v for k, v in zip(cols, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
bow_rep.append((reverse_voc[k], round(v,2)))
print("SPLADE BOW rep:\n", bow_rep)
📚 详细文档
1. 什么是稀疏表示以及为什么要学习它
展开查看详细内容
词法搜索
基于词袋(BOW)的稀疏向量进行词法搜索是强大的基线方法,但存在词汇不匹配问题,只能进行精确的词项匹配。其优缺点如下:
- ✅ 高效且成本低。
- ✅ 无需微调模型。
- ✅ 可解释性强。
- ✅ 精确的词项匹配。
- ❌ 词汇不匹配(需要记住精确的词项)。
语义搜索
学习型神经/密集检索器(如DPR、Sentence transformers*、BGE*模型)结合近似最近邻搜索已取得显著成果。其优缺点如下:
- ✅ 搜索方式符合人类的自然思维。
- ✅ 微调后在性能上远超稀疏搜索。
- ✅ 易于处理多模态数据。
- ❌ 存在词项遗忘问题(错过词项匹配)。
- ❌ 资源密集(索引和检索都需要大量资源)。
- ❌ 著名的难以解释。
- ❌ 对于分布外(OOD)数据需要微调。
核心理念
结合两种搜索的优点,引发了人们对学习具有一定可解释性的查询和文档的稀疏表示的兴趣。稀疏表示还可作为查询和文档的隐式或显式(潜在、上下文相关)扩展机制。
稀疏模型学习内容
模型学习将其学习到的密集表示投影到掩码语言模型(MLM)头部,以给出词汇分布。也就是说,模型可以进行自动词项扩展。
2. 动机
SPLADE模型在检索效果(质量)和检索效率(延迟和成本)之间取得了很好的平衡。在此基础上,我们进行了非常细微的检索效率调整,使其更适合工业环境。
3. 为什么FLOPS是工业场景的关键指标
展开查看详细内容
我们的模型在实现相当竞争力的检索效果的同时,使用的令牌数量明显少于其他模型,从而减少了FLOPS。以下是不同模型的对比示例:我们的模型
number of actual dimensions: 113
SPLADE BOW rep:
[('stress', 2.36), ('glass', 2.15), ('thermal', 2.06), ('pan', 1.83), ('glasses', 1.67), ('break', 1.47), ('crack', 1.47), ('heat', 1.45), ('warmth', 1.36), ('depression', 1.34), ('hotter', 1.23), ('hottest', 1.11), ('window', 1.11), ('hot', 1.1), ('area', 1.04), ('cause', 1.01), ('adjacent', 0.99), ('too', 0.94), ('created', 0.86), ('##pan', 0.84), ('phenomenon', 0.81), ('when', 0.78), ('temperature', 0.76), ('cracked', 0.75), ('factors', 0.74), ('windows', 0.72), ('create', 0.71), ('level', 0.7), ('formed', 0.61), ('stresses', 0.59), ('warm', 0.58), ('fracture', 0.57), ('adjoining', 0.56), ('areas', 0.56), ('nearby', 0.56), ('causes', 0.56), ('broken', 0.54), ('produced', 0.52), ('sash', 0.51), ('if', 0.51), ('breaks', 0.49), ('is', 0.49), ('effect', 0.45), ('heated', 0.44), ('process', 0.42), ('breaking', 0.42), ('one', 0.4), ('mirror', 0.39), ('factor', 0.38), ('shatter', 0.38), ('formation', 0.37), ('mathias', 0.37), ('damage', 0.36), ('cracking', 0.35), ('climate', 0.35), ('ceramic', 0.34), ('reaction', 0.34), ('steam', 0.33), ('reflection', 0.33), ('generated', 0.33), ('material', 0.32), ('burst', 0.31), ('fire', 0.31), ('neighboring', 0.3), ('explosion', 0.29), ('caused', 0.29), ('warmer', 0.29), ('because', 0.28), ('anxiety', 0.28), ('furnace', 0.28), ('tear', 0.27), ('induced', 0.27), ('fail', 0.26), ('are', 0.26), ('collapse', 0.26), ('##thermal', 0.26), ('and', 0.25), ('great', 0.25), ('get', 0.24), ('spark', 0.23), ('lens', 0.2), ('cooler', 0.19), ('determined', 0.19), ('leak', 0.19), ('disease', 0.19), ('emotion', 0.16), ('cork', 0.14), ('cooling', 0.14), ('heating', 0.13), ('governed', 0.13), ('optical', 0.12), ('surrounding', 0.12), ('warming', 0.12), ('convection', 0.11), ('regulated', 0.11), ('problem', 0.1), ('cool', 0.09), ('violence', 0.09), ('breaker', 0.09), ('image', 0.09), ('photo', 0.05), ('strike', 0.05), ('.', 0.04), ('shattering', 0.04), ('snap', 0.03), ('wilson', 0.03), ('weather', 0.02), ('eye', 0.02), ('produce', 0.01), ('crime', 0.01), ('humid', 0.0), ('impact', 0.0), ('earthquake', 0.0)]
naver/splade - cocondenser - ensembledistil (当前最优,令牌数量多约10%,FLOPS = 1.85)
number of actual dimensions: 126
SPLADE BOW rep:
[('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determined', 0.32), ('gage', 0.32), ('sash', 0.31), ('theory', 0.31), ('level', 0.31), ('resistant', 0.31), ('brake', 0.3), ('window', 0.3), ('crash', 0.3), ('hazard', 0.29), ('##ink', 0.27), ('ceramic', 0.27), ('storm', 0.25), ('problem', 0.25), ('issue', 0.24), ('impact', 0.24), ('fridge', 0.24), ('injury', 0.23), ('ross', 0.22), ('causes', 0.22), ('affect', 0.21), ('pressure', 0.21), ('fatigue', 0.21), ('leak', 0.21), ('eye', 0.2), ('frank', 0.2), ('cool', 0.2), ('might', 0.19), ('gravity', 0.18), ('ray', 0.18), ('static', 0.18), ('collapse', 0.18), ('physics', 0.18), ('wave', 0.18), ('reflection', 0.17), ('parker', 0.17), ('strike', 0.17), ('hottest', 0.17), ('burst', 0.16), ('chance', 0.16), ('burn', 0.14), ('rubbing', 0.14), ('interference', 0.14), ('bailey', 0.13), ('vibration', 0.12), ('gilbert', 0.12), ('produced', 0.12), ('rock', 0.12), ('warmer', 0.11), ('get', 0.11), ('drink', 0.11), ('fireplace', 0.11), ('ruin', 0.1), ('brittle', 0.1), ('fragment', 0.1), ('stumble', 0.09), ('formation', 0.09), ('shatter', 0.08), ('great', 0.08), ('friction', 0.08), ('flash', 0.07), ('cracks', 0.07), ('levels', 0.07), ('smash', 0.04), ('fail', 0.04), ('fra', 0.04), ('##glass', 0.03), ('variables', 0.03), ('because', 0.02), ('knock', 0.02), ('sun', 0.02), ('crush', 0.01), ('##e', 0.01), ('anger', 0.01)]
naver/splade - v2 - distil (令牌数量多约100%,FLOPS = 3.82)
number of actual dimensions: 234
SPLADE BOW rep:
[('glass', 2.55), ('stress', 2.39), ('thermal', 2.38), ('glasses', 1.95), ('stressed', 1.87), ('crack', 1.84), ('cool', 1.78), ('heat', 1.62), ('pan', 1.6), ('break', 1.53), ('adjacent', 1.44), ('hotter', 1.43), ('strain', 1.21), ('area', 1.16), ('adjoining', 1.14), ('heated', 1.11), ('window', 1.07), ('stresses', 1.04), ('hot', 1.03), ('created', 1.03), ('create', 1.03), ('cause', 1.02), ('factors', 1.02), ('cooler', 1.01), ('broken', 1.0), ('too', 0.99), ('fracture', 0.96), ('collapse', 0.96), ('cracking', 0.95), ('great', 0.93), ('happen', 0.93), ('windows', 0.89), ('broke', 0.87), ('##e', 0.87), ('pressure', 0.84), ('hottest', 0.84), ('breaking', 0.83), ('govern', 0.79), ('shatter', 0.76), ('level', 0.75), ('heating', 0.69), ('temperature', 0.69), ('cracked', 0.69), ('panel', 0.68), ('##glass', 0.68), ('ceramic', 0.67), ('sash', 0.66), ('warm', 0.66), ('areas', 0.64), ('creating', 0.63), ('will', 0.62), ('tension', 0.61), ('cracks', 0.61), ('optical', 0.6), ('mechanism', 0.58), ('kelly', 0.58), ('determined', 0.58), ('generate', 0.58), ('causes', 0.56), ('if', 0.56), ('factor', 0.56), ('the', 0.56), ('chemical', 0.55), ('governed', 0.55), ('crystal', 0.55), ('strike', 0.55), ('microsoft', 0.54), ('creates', 0.53), ('than', 0.53), ('relation', 0.53), ('glazed', 0.52), ('compression', 0.51), ('painting', 0.51), ('governing', 0.5), ('harden', 0.49), ('solar', 0.48), ('reflection', 0.48), ('ic', 0.46), ('split', 0.45), ('mirror', 0.44), ('damage', 0.43), ('ring', 0.42), ('formation', 0.42), ('wall', 0.41), ('burst', 0.4), ('radiant', 0.4), ('determine', 0.4), ('one', 0.4), ('plastic', 0.39), ('furnace', 0.39), ('difference', 0.39), ('melt', 0.39), ('get', 0.39), ('contract', 0.38), ('forces', 0.38), ('gets', 0.38), ('produce', 0.38), ('surrounding', 0.37), ('vibration', 0.37), ('tile', 0.37), ('fail', 0.36), ('warmer', 0.36), ('rock', 0.35), ('fault', 0.35), ('roof', 0.34), ('burned', 0.34), ('physics', 0.33), ('welding', 0.33), ('why', 0.33), ('a', 0.32), ('pop', 0.32), ('and', 0.31), ('fra', 0.3), ('stat', 0.3), ('withstand', 0.3), ('sunglasses', 0.3), ('material', 0.29), ('ice', 0.29), ('generated', 0.29), ('matter', 0.29), ('frame', 0.28), ('elements', 0.28), ('then', 0.28), ('.', 0.28), ('pont', 0.28), ('blow', 0.28), ('snap', 0.27), ('metal', 0.26), ('effect', 0.26), ('reaction', 0.26), ('related', 0.25), ('aluminium', 0.25), ('neighboring', 0.25), ('weight', 0.25), ('steel', 0.25), ('bulb', 0.25), ('tear', 0.25), ('coating', 0.25), ('plumbing', 0.25), ('co', 0.25), ('microwave', 0.24), ('formed', 0.24), ('pipe', 0.23), ('drink', 0.23), ('chemistry', 0.23), ('energy', 0.22), ('reflect', 0.22), ('dynamic', 0.22), ('leak', 0.22), ('is', 0.22), ('lens', 0.21), ('frost', 0.21), ('lenses', 0.21), ('produced', 0.21), ('induced', 0.2), ('arise', 0.2), ('plate', 0.2), ('equations', 0.19), ('affect', 0.19), ('tired', 0.19), ('mirrors', 0.18), ('thickness', 0.18), ('bending', 0.18), ('cabinet', 0.17), ('apart', 0.17), ('##thermal', 0.17), ('gas', 0.17), ('equation', 0.17), ('relationship', 0.17), ('composition', 0.17), ('engineering', 0.17), ('block', 0.16), ('breaks', 0.16), ('when', 0.16), ('definition', 0.16), ('collapsed', 0.16), ('generation', 0.16), (',', 0.16), ('philips', 0.16), ('later', 0.15), ('wood', 0.15), ('neighbouring', 0.15), ('structural', 0.14), ('regulate', 0.14), ('neighbors', 0.13), ('lighting', 0.13), ('happens', 0.13), ('more', 0.13), ('property', 0.13), ('cooling', 0.12), ('shattering', 0.12), ('melting', 0.12), ('how', 0.11), ('cloud', 0.11), ('barriers', 0.11), ('lam', 0.11), ('conditions', 0.11), ('rule', 0.1), ('insulation', 0.1), ('bathroom', 0.09), ('convection', 0.09), ('cavity', 0.09), ('source', 0.08), ('properties', 0.08), ('bend', 0.08), ('bottles', 0.08), ('ceramics', 0.07), ('temper', 0.07), ('tense', 0.07), ('keller', 0.07), ('breakdown', 0.07), ('concrete', 0.07), ('simon', 0.07), ('solids', 0.06), ('windshield', 0.05), ('eye', 0.05), ('sunlight', 0.05), ('brittle', 0.03), ('caused', 0.03), ('suns', 0.03), ('floor', 0.02), ('components', 0.02), ('photo', 0.02), ('change', 0.02), ('sun', 0.01), ('crystals', 0.01), ('problem', 0.01), ('##proof', 0.01), ('parameters', 0.01), ('gases', 0.0), ('prism', 0.0), ('doing', 0.0), ('lattice', 0.0), ('ground', 0.0)]
- 注意1:此特定段落用作比较示例
4. 如何转化为实证指标
我们的模型令牌稀疏但有效,这意味着更快的检索速度(用户体验)和更小的索引大小(成本)。以下是在标准MS - MARCO小开发集上的平均检索时间和缩放后的总FLOPS损失指标:
展开查看更多说明
注意:为什么选择Anserini而不是PISA? Anserini是一个基于Lucene的生产就绪库。常见的工业搜索部署使用基于Lucene的Solr或Elastic,因此性能具有可比性。PISA的延迟对于工业应用无关紧要,因为它只是一个研究系统。完整的Anserini评估日志包含编码、索引和查询的详细信息。
- BEIR ZST OOD性能:将添加到页面末尾。
我们的模型在以下方面有所不同
- 共冷凝器权重:与官方最优的SPLADE++或SparseEmbed不同,我们不使用Luyu/co - condenser*模型初始化权重,但仍达到了CoCondenser SPLADE级别的性能。
- 相同大小的模型:官方SPLADE++、SparseEmbed和我们的模型都在相同大小的基础模型上进行微调,即
bert - base - uncased
。
5. 工业适用性的路线图和未来方向
- 提高效率:持续改进服务和检索效率。
- 自定义/领域微调:探索如何在不进行昂贵标注的情况下,在自定义数据集或领域上进行经济有效的微调。
- 多语言SPLADE:研究如何将模型扩展到多语言环境,解决多语言模型训练成本高的问题。
6. 使用方法
与流行向量数据库结合使用
向量数据库 | Colab链接 |
---|---|
Pinecone | [ |
Qdrant | 待确定 |
使用SPLADERunner库
pip install spladerunner
#One-time init
from spladerunner import Expander
# Default model is the document expander.
exapander = Expander()
#Sample Document expansion
sparse_rep = expander.expand(
["The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."])
使用HuggingFace
笔记本用户:先登录
!huggingface-cli login
集成到代码中 [如何在代码中使用HF令牌](https://huggingface.co/docs/hub/en/security - tokens) 进行如下更改:
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1', token=<Your token>)
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1', token=<Your token>)
完整代码
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
model.to(device)
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
inputs = tokenizer(sentence, return_tensors='pt')
inputs = {key: val.to(device) for key, val in inputs.items()}
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
outputs = model(**inputs)
logits, attention_mask = outputs.logits, attention_mask
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vector = max_val.squeeze()
cols = vector.nonzero().squeeze().cpu().tolist()
print("number of actual dimensions: ", len(cols))
weights = vector[cols].cpu().tolist()
d = {k: v for k, v in zip(cols, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
bow_rep.append((reverse_voc[k], round(v,2)))
print("SPLADE BOW rep:\n", bow_rep)
BEIR零样本OOD性能
训练细节
待确定
🔧 技术细节
模型优化策略
- FLOPS调整:采用与官方SPLADE++不同的序列长度和严格受限的FLOPS调度及令牌预算,文档为128,查询为24,而非256,灵感来源于Google的SparseEmbed。
- 初始化权重:使用Vanilla
bert - base - uncased
,不像官方SPLADE++ / ColBERT那样依赖语料库感知。
性能对比分析
与其他模型相比,在保证检索效果的同时,通过减少令牌数量和FLOPS,实现了更快的检索速度和更小的索引大小。
📄 许可证
本项目采用Apache 2.0许可证。
致谢
- 感谢Nils Reimers提供的所有建议。
- 感谢Anserini库的作者。
局限性和偏差
BERT模型的所有局限性和偏差同样适用于本微调工作。
引用
如果您使用了我们的模型或库,请进行引用,引用信息如下:
Damodaran, P. (2024). Splade_PP_en_v1: Independent Implementation of SPLADE++ Model (`a.k.a splade - cocondenser* and family`) for the Industry setting. (Version 1.0.0) [Computer software].







