Splade PP En V1
模型概述
模型特點
模型能力
使用案例
🚀 獨立實現適用於工業場景的SPLADE++模型
本項目獨立實現了SPLADE++模型(又名splade - cocondenser*及其系列),並針對工業場景進行了一些效率調整。該模型結合了兩項強大研究的優勢,旨在平衡檢索效果(質量)和檢索效率(延遲和成本),為工業應用提供更合適的解決方案。
🚀 快速開始
本工作借鑑了兩項重要研究:Naver的《From Distillation to Hard Negative Sampling: Making Sparse Neural IR Models More Effective》論文和Google的SparseEmbed。感謝這兩個團隊的傑出工作。
✨ 主要特性
1. 稀疏表示與語義搜索的結合
- 結合了詞法搜索和語義搜索的優點,學習具有一定可解釋性的查詢和文檔的稀疏表示。
- 稀疏表示可作為查詢和文檔的隱式或顯式(潛在、上下文相關)擴展機制。
2. 工業場景的效率優化
- 進行了非常細微的檢索效率調整,使其更適合工業環境。
- 實現了檢索效果(質量)和檢索效率(延遲和成本)的良好平衡。
3. 模型性能表現
- 在ID數據上達到了MRR@10為37.22的競爭力效果(OOD為48.7),且檢索延遲為47.27ms(多線程)。
- 在消費級GPU上,每個查詢僅使用5個負樣本即可實現上述性能。
📦 安裝指南
暫未提供明確的安裝步驟,可參考後續使用部分結合自身環境進行配置。
💻 使用示例
基礎用法
# 以下是使用SPLADERunner庫進行文檔擴展的示例
pip install spladerunner
#One-time init
from spladerunner import Expander
# Default model is the document expander.
exapander = Expander()
#Sample Document expansion
sparse_rep = expander.expand(
["The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."])
高級用法
# 使用HuggingFace庫進行模型調用的完整代碼示例
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
model.to(device)
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
inputs = tokenizer(sentence, return_tensors='pt')
inputs = {key: val.to(device) for key, val in inputs.items()}
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
outputs = model(**inputs)
logits, attention_mask = outputs.logits, attention_mask
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vector = max_val.squeeze()
cols = vector.nonzero().squeeze().cpu().tolist()
print("number of actual dimensions: ", len(cols))
weights = vector[cols].cpu().tolist()
d = {k: v for k, v in zip(cols, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
bow_rep.append((reverse_voc[k], round(v,2)))
print("SPLADE BOW rep:\n", bow_rep)
📚 詳細文檔
1. 什麼是稀疏表示以及為什麼要學習它
展開查看詳細內容
詞法搜索
基於詞袋(BOW)的稀疏向量進行詞法搜索是強大的基線方法,但存在詞彙不匹配問題,只能進行精確的詞項匹配。其優缺點如下:
- ✅ 高效且成本低。
- ✅ 無需微調模型。
- ✅ 可解釋性強。
- ✅ 精確的詞項匹配。
- ❌ 詞彙不匹配(需要記住精確的詞項)。
語義搜索
學習型神經/密集檢索器(如DPR、Sentence transformers*、BGE*模型)結合近似最近鄰搜索已取得顯著成果。其優缺點如下:
- ✅ 搜索方式符合人類的自然思維。
- ✅ 微調後在性能上遠超稀疏搜索。
- ✅ 易於處理多模態數據。
- ❌ 存在詞項遺忘問題(錯過詞項匹配)。
- ❌ 資源密集(索引和檢索都需要大量資源)。
- ❌ 著名的難以解釋。
- ❌ 對於分佈外(OOD)數據需要微調。
核心理念
結合兩種搜索的優點,引發了人們對學習具有一定可解釋性的查詢和文檔的稀疏表示的興趣。稀疏表示還可作為查詢和文檔的隱式或顯式(潛在、上下文相關)擴展機制。
稀疏模型學習內容
模型學習將其學習到的密集表示投影到掩碼語言模型(MLM)頭部,以給出詞彙分佈。也就是說,模型可以進行自動詞項擴展。
2. 動機
SPLADE模型在檢索效果(質量)和檢索效率(延遲和成本)之間取得了很好的平衡。在此基礎上,我們進行了非常細微的檢索效率調整,使其更適合工業環境。
3. 為什麼FLOPS是工業場景的關鍵指標
展開查看詳細內容
我們的模型在實現相當競爭力的檢索效果的同時,使用的令牌數量明顯少於其他模型,從而減少了FLOPS。以下是不同模型的對比示例:我們的模型
number of actual dimensions: 113
SPLADE BOW rep:
[('stress', 2.36), ('glass', 2.15), ('thermal', 2.06), ('pan', 1.83), ('glasses', 1.67), ('break', 1.47), ('crack', 1.47), ('heat', 1.45), ('warmth', 1.36), ('depression', 1.34), ('hotter', 1.23), ('hottest', 1.11), ('window', 1.11), ('hot', 1.1), ('area', 1.04), ('cause', 1.01), ('adjacent', 0.99), ('too', 0.94), ('created', 0.86), ('##pan', 0.84), ('phenomenon', 0.81), ('when', 0.78), ('temperature', 0.76), ('cracked', 0.75), ('factors', 0.74), ('windows', 0.72), ('create', 0.71), ('level', 0.7), ('formed', 0.61), ('stresses', 0.59), ('warm', 0.58), ('fracture', 0.57), ('adjoining', 0.56), ('areas', 0.56), ('nearby', 0.56), ('causes', 0.56), ('broken', 0.54), ('produced', 0.52), ('sash', 0.51), ('if', 0.51), ('breaks', 0.49), ('is', 0.49), ('effect', 0.45), ('heated', 0.44), ('process', 0.42), ('breaking', 0.42), ('one', 0.4), ('mirror', 0.39), ('factor', 0.38), ('shatter', 0.38), ('formation', 0.37), ('mathias', 0.37), ('damage', 0.36), ('cracking', 0.35), ('climate', 0.35), ('ceramic', 0.34), ('reaction', 0.34), ('steam', 0.33), ('reflection', 0.33), ('generated', 0.33), ('material', 0.32), ('burst', 0.31), ('fire', 0.31), ('neighboring', 0.3), ('explosion', 0.29), ('caused', 0.29), ('warmer', 0.29), ('because', 0.28), ('anxiety', 0.28), ('furnace', 0.28), ('tear', 0.27), ('induced', 0.27), ('fail', 0.26), ('are', 0.26), ('collapse', 0.26), ('##thermal', 0.26), ('and', 0.25), ('great', 0.25), ('get', 0.24), ('spark', 0.23), ('lens', 0.2), ('cooler', 0.19), ('determined', 0.19), ('leak', 0.19), ('disease', 0.19), ('emotion', 0.16), ('cork', 0.14), ('cooling', 0.14), ('heating', 0.13), ('governed', 0.13), ('optical', 0.12), ('surrounding', 0.12), ('warming', 0.12), ('convection', 0.11), ('regulated', 0.11), ('problem', 0.1), ('cool', 0.09), ('violence', 0.09), ('breaker', 0.09), ('image', 0.09), ('photo', 0.05), ('strike', 0.05), ('.', 0.04), ('shattering', 0.04), ('snap', 0.03), ('wilson', 0.03), ('weather', 0.02), ('eye', 0.02), ('produce', 0.01), ('crime', 0.01), ('humid', 0.0), ('impact', 0.0), ('earthquake', 0.0)]
naver/splade - cocondenser - ensembledistil (當前最優,令牌數量多約10%,FLOPS = 1.85)
number of actual dimensions: 126
SPLADE BOW rep:
[('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determined', 0.32), ('gage', 0.32), ('sash', 0.31), ('theory', 0.31), ('level', 0.31), ('resistant', 0.31), ('brake', 0.3), ('window', 0.3), ('crash', 0.3), ('hazard', 0.29), ('##ink', 0.27), ('ceramic', 0.27), ('storm', 0.25), ('problem', 0.25), ('issue', 0.24), ('impact', 0.24), ('fridge', 0.24), ('injury', 0.23), ('ross', 0.22), ('causes', 0.22), ('affect', 0.21), ('pressure', 0.21), ('fatigue', 0.21), ('leak', 0.21), ('eye', 0.2), ('frank', 0.2), ('cool', 0.2), ('might', 0.19), ('gravity', 0.18), ('ray', 0.18), ('static', 0.18), ('collapse', 0.18), ('physics', 0.18), ('wave', 0.18), ('reflection', 0.17), ('parker', 0.17), ('strike', 0.17), ('hottest', 0.17), ('burst', 0.16), ('chance', 0.16), ('burn', 0.14), ('rubbing', 0.14), ('interference', 0.14), ('bailey', 0.13), ('vibration', 0.12), ('gilbert', 0.12), ('produced', 0.12), ('rock', 0.12), ('warmer', 0.11), ('get', 0.11), ('drink', 0.11), ('fireplace', 0.11), ('ruin', 0.1), ('brittle', 0.1), ('fragment', 0.1), ('stumble', 0.09), ('formation', 0.09), ('shatter', 0.08), ('great', 0.08), ('friction', 0.08), ('flash', 0.07), ('cracks', 0.07), ('levels', 0.07), ('smash', 0.04), ('fail', 0.04), ('fra', 0.04), ('##glass', 0.03), ('variables', 0.03), ('because', 0.02), ('knock', 0.02), ('sun', 0.02), ('crush', 0.01), ('##e', 0.01), ('anger', 0.01)]
naver/splade - v2 - distil (令牌數量多約100%,FLOPS = 3.82)
number of actual dimensions: 234
SPLADE BOW rep:
[('glass', 2.55), ('stress', 2.39), ('thermal', 2.38), ('glasses', 1.95), ('stressed', 1.87), ('crack', 1.84), ('cool', 1.78), ('heat', 1.62), ('pan', 1.6), ('break', 1.53), ('adjacent', 1.44), ('hotter', 1.43), ('strain', 1.21), ('area', 1.16), ('adjoining', 1.14), ('heated', 1.11), ('window', 1.07), ('stresses', 1.04), ('hot', 1.03), ('created', 1.03), ('create', 1.03), ('cause', 1.02), ('factors', 1.02), ('cooler', 1.01), ('broken', 1.0), ('too', 0.99), ('fracture', 0.96), ('collapse', 0.96), ('cracking', 0.95), ('great', 0.93), ('happen', 0.93), ('windows', 0.89), ('broke', 0.87), ('##e', 0.87), ('pressure', 0.84), ('hottest', 0.84), ('breaking', 0.83), ('govern', 0.79), ('shatter', 0.76), ('level', 0.75), ('heating', 0.69), ('temperature', 0.69), ('cracked', 0.69), ('panel', 0.68), ('##glass', 0.68), ('ceramic', 0.67), ('sash', 0.66), ('warm', 0.66), ('areas', 0.64), ('creating', 0.63), ('will', 0.62), ('tension', 0.61), ('cracks', 0.61), ('optical', 0.6), ('mechanism', 0.58), ('kelly', 0.58), ('determined', 0.58), ('generate', 0.58), ('causes', 0.56), ('if', 0.56), ('factor', 0.56), ('the', 0.56), ('chemical', 0.55), ('governed', 0.55), ('crystal', 0.55), ('strike', 0.55), ('microsoft', 0.54), ('creates', 0.53), ('than', 0.53), ('relation', 0.53), ('glazed', 0.52), ('compression', 0.51), ('painting', 0.51), ('governing', 0.5), ('harden', 0.49), ('solar', 0.48), ('reflection', 0.48), ('ic', 0.46), ('split', 0.45), ('mirror', 0.44), ('damage', 0.43), ('ring', 0.42), ('formation', 0.42), ('wall', 0.41), ('burst', 0.4), ('radiant', 0.4), ('determine', 0.4), ('one', 0.4), ('plastic', 0.39), ('furnace', 0.39), ('difference', 0.39), ('melt', 0.39), ('get', 0.39), ('contract', 0.38), ('forces', 0.38), ('gets', 0.38), ('produce', 0.38), ('surrounding', 0.37), ('vibration', 0.37), ('tile', 0.37), ('fail', 0.36), ('warmer', 0.36), ('rock', 0.35), ('fault', 0.35), ('roof', 0.34), ('burned', 0.34), ('physics', 0.33), ('welding', 0.33), ('why', 0.33), ('a', 0.32), ('pop', 0.32), ('and', 0.31), ('fra', 0.3), ('stat', 0.3), ('withstand', 0.3), ('sunglasses', 0.3), ('material', 0.29), ('ice', 0.29), ('generated', 0.29), ('matter', 0.29), ('frame', 0.28), ('elements', 0.28), ('then', 0.28), ('.', 0.28), ('pont', 0.28), ('blow', 0.28), ('snap', 0.27), ('metal', 0.26), ('effect', 0.26), ('reaction', 0.26), ('related', 0.25), ('aluminium', 0.25), ('neighboring', 0.25), ('weight', 0.25), ('steel', 0.25), ('bulb', 0.25), ('tear', 0.25), ('coating', 0.25), ('plumbing', 0.25), ('co', 0.25), ('microwave', 0.24), ('formed', 0.24), ('pipe', 0.23), ('drink', 0.23), ('chemistry', 0.23), ('energy', 0.22), ('reflect', 0.22), ('dynamic', 0.22), ('leak', 0.22), ('is', 0.22), ('lens', 0.21), ('frost', 0.21), ('lenses', 0.21), ('produced', 0.21), ('induced', 0.2), ('arise', 0.2), ('plate', 0.2), ('equations', 0.19), ('affect', 0.19), ('tired', 0.19), ('mirrors', 0.18), ('thickness', 0.18), ('bending', 0.18), ('cabinet', 0.17), ('apart', 0.17), ('##thermal', 0.17), ('gas', 0.17), ('equation', 0.17), ('relationship', 0.17), ('composition', 0.17), ('engineering', 0.17), ('block', 0.16), ('breaks', 0.16), ('when', 0.16), ('definition', 0.16), ('collapsed', 0.16), ('generation', 0.16), (',', 0.16), ('philips', 0.16), ('later', 0.15), ('wood', 0.15), ('neighbouring', 0.15), ('structural', 0.14), ('regulate', 0.14), ('neighbors', 0.13), ('lighting', 0.13), ('happens', 0.13), ('more', 0.13), ('property', 0.13), ('cooling', 0.12), ('shattering', 0.12), ('melting', 0.12), ('how', 0.11), ('cloud', 0.11), ('barriers', 0.11), ('lam', 0.11), ('conditions', 0.11), ('rule', 0.1), ('insulation', 0.1), ('bathroom', 0.09), ('convection', 0.09), ('cavity', 0.09), ('source', 0.08), ('properties', 0.08), ('bend', 0.08), ('bottles', 0.08), ('ceramics', 0.07), ('temper', 0.07), ('tense', 0.07), ('keller', 0.07), ('breakdown', 0.07), ('concrete', 0.07), ('simon', 0.07), ('solids', 0.06), ('windshield', 0.05), ('eye', 0.05), ('sunlight', 0.05), ('brittle', 0.03), ('caused', 0.03), ('suns', 0.03), ('floor', 0.02), ('components', 0.02), ('photo', 0.02), ('change', 0.02), ('sun', 0.01), ('crystals', 0.01), ('problem', 0.01), ('##proof', 0.01), ('parameters', 0.01), ('gases', 0.0), ('prism', 0.0), ('doing', 0.0), ('lattice', 0.0), ('ground', 0.0)]
- 注意1:此特定段落用作比較示例
4. 如何轉化為實證指標
我們的模型令牌稀疏但有效,這意味著更快的檢索速度(用戶體驗)和更小的索引大小(成本)。以下是在標準MS - MARCO小開發集上的平均檢索時間和縮放後的總FLOPS損失指標:
展開查看更多說明
注意:為什麼選擇Anserini而不是PISA? Anserini是一個基於Lucene的生產就緒庫。常見的工業搜索部署使用基於Lucene的Solr或Elastic,因此性能具有可比性。PISA的延遲對於工業應用無關緊要,因為它只是一個研究系統。完整的Anserini評估日誌包含編碼、索引和查詢的詳細信息。
- BEIR ZST OOD性能:將添加到頁面末尾。
我們的模型在以下方面有所不同
- 共冷凝器權重:與官方最優的SPLADE++或SparseEmbed不同,我們不使用Luyu/co - condenser*模型初始化權重,但仍達到了CoCondenser SPLADE級別的性能。
- 相同大小的模型:官方SPLADE++、SparseEmbed和我們的模型都在相同大小的基礎模型上進行微調,即
bert - base - uncased
。
5. 工業適用性的路線圖和未來方向
- 提高效率:持續改進服務和檢索效率。
- 自定義/領域微調:探索如何在不進行昂貴標註的情況下,在自定義數據集或領域上進行經濟有效的微調。
- 多語言SPLADE:研究如何將模型擴展到多語言環境,解決多語言模型訓練成本高的問題。
6. 使用方法
與流行向量數據庫結合使用
向量數據庫 | Colab鏈接 |
---|---|
Pinecone | [ |
Qdrant | 待確定 |
使用SPLADERunner庫
pip install spladerunner
#One-time init
from spladerunner import Expander
# Default model is the document expander.
exapander = Expander()
#Sample Document expansion
sparse_rep = expander.expand(
["The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."])
使用HuggingFace
筆記本用戶:先登錄
!huggingface-cli login
集成到代碼中 [如何在代碼中使用HF令牌](https://huggingface.co/docs/hub/en/security - tokens) 進行如下更改:
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1', token=<Your token>)
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1', token=<Your token>)
完整代碼
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v1')
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}
model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v1')
model.to(device)
sentence = """The Manhattan Project and its atomic bomb helped bring an end to World War II. Its legacy of peaceful uses of atomic energy continues to have an impact on history and science."""
inputs = tokenizer(sentence, return_tensors='pt')
inputs = {key: val.to(device) for key, val in inputs.items()}
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']
outputs = model(**inputs)
logits, attention_mask = outputs.logits, attention_mask
relu_log = torch.log(1 + torch.relu(logits))
weighted_log = relu_log * attention_mask.unsqueeze(-1)
max_val, _ = torch.max(weighted_log, dim=1)
vector = max_val.squeeze()
cols = vector.nonzero().squeeze().cpu().tolist()
print("number of actual dimensions: ", len(cols))
weights = vector[cols].cpu().tolist()
d = {k: v for k, v in zip(cols, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
bow_rep.append((reverse_voc[k], round(v,2)))
print("SPLADE BOW rep:\n", bow_rep)
BEIR零樣本OOD性能
訓練細節
待確定
🔧 技術細節
模型優化策略
- FLOPS調整:採用與官方SPLADE++不同的序列長度和嚴格受限的FLOPS調度及令牌預算,文檔為128,查詢為24,而非256,靈感來源於Google的SparseEmbed。
- 初始化權重:使用Vanilla
bert - base - uncased
,不像官方SPLADE++ / ColBERT那樣依賴語料庫感知。
性能對比分析
與其他模型相比,在保證檢索效果的同時,通過減少令牌數量和FLOPS,實現了更快的檢索速度和更小的索引大小。
📄 許可證
本項目採用Apache 2.0許可證。
致謝
- 感謝Nils Reimers提供的所有建議。
- 感謝Anserini庫的作者。
侷限性和偏差
BERT模型的所有侷限性和偏差同樣適用於本微調工作。
引用
如果您使用了我們的模型或庫,請進行引用,引用信息如下:
Damodaran, P. (2024). Splade_PP_en_v1: Independent Implementation of SPLADE++ Model (`a.k.a splade - cocondenser* and family`) for the Industry setting. (Version 1.0.0) [Computer software].







