模型概述
模型特點
模型能力
使用案例
🚀 stella模型
stella是一個通用的文本編碼模型,可將文本轉換為向量表示,用於檢索、聚類、分類等自然語言處理任務。該模型主要有以下幾種:
模型名稱 | 模型大小 (GB) | 向量維度 | 序列長度 | 支持語言 | 檢索時是否需要指令? |
---|---|---|---|---|---|
stella-base-en-v2 | 0.2 | 768 | 512 | 英文 | 否 |
stella-large-zh-v2 | 0.65 | 1024 | 1024 | 中文 | 否 |
stella-base-zh-v2 | 0.2 | 768 | 1024 | 中文 | 否 |
stella-large-zh | 0.65 | 1024 | 1024 | 中文 | 是 |
stella-base-zh | 0.2 | 768 | 1024 | 中文 | 是 |
完整的訓練思路和訓練過程已記錄在博客1和博客2,歡迎閱讀討論。
✨ 主要特性
- 多語言支持:涵蓋中文和英文,適用於不同語言場景。
- 使用便捷:部分模型無需額外的前綴文本,降低使用門檻。
- 長文本處理:具備一定的長文本編碼能力,通過特殊處理規則優化效果。
- 持續更新:不斷推出新的版本和模型,提升性能和功能。
📦 安裝指南
在使用stella模型前,你需要安裝相關依賴庫。可以使用以下命令進行安裝:
pip install sentence-transformers transformers sklearn torch numpy mteb
💻 使用示例
stella 中文系列模型
在sentence-transformer庫中的使用方法:
from sentence_transformers import SentenceTransformer
sentences = ["數據1", "數據2"]
model = SentenceTransformer('infgrad/stella-base-zh-v2')
print(model.max_seq_length)
embeddings_1 = model.encode(sentences, normalize_embeddings=True)
embeddings_2 = model.encode(sentences, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
直接使用transformers庫:
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
model = AutoModel.from_pretrained('infgrad/stella-base-zh-v2')
tokenizer = AutoTokenizer.from_pretrained('infgrad/stella-base-zh-v2')
sentences = ["數據1", "數據ABCDEFGH"]
batch_data = tokenizer(
batch_text_or_text_pairs=sentences,
padding="longest",
return_tensors="pt",
max_length=1024,
truncation=True,
)
attention_mask = batch_data["attention_mask"]
model_output = model(**batch_data)
last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
vectors = normalize(vectors, norm="l2", axis=1, )
print(vectors.shape) # 2,768
stella 英文系列模型
使用Sentence-Transformers:
from sentence_transformers import SentenceTransformer
sentences = ["one car come", "one car go"]
model = SentenceTransformer('infgrad/stella-base-en-v2')
print(model.max_seq_length)
embeddings_1 = model.encode(sentences, normalize_embeddings=True)
embeddings_2 = model.encode(sentences, normalize_embeddings=True)
similarity = embeddings_1 @ embeddings_2.T
print(similarity)
使用HuggingFace Transformers:
from transformers import AutoModel, AutoTokenizer
from sklearn.preprocessing import normalize
model = AutoModel.from_pretrained('infgrad/stella-base-en-v2')
tokenizer = AutoTokenizer.from_pretrained('infgrad/stella-base-en-v2')
sentences = ["one car come", "one car go"]
batch_data = tokenizer(
batch_text_or_text_pairs=sentences,
padding="longest",
return_tensors="pt",
max_length=512,
truncation=True,
)
attention_mask = batch_data["attention_mask"]
model_output = model(**batch_data)
last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
vectors = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
vectors = normalize(vectors, norm="l2", axis=1, )
print(vectors.shape) # 2,768
📚 詳細文檔
訓練數據
- 開源數據:選用了wudao_base_200GB[1]、m3e[2]和simclue[3]等開源數據,並著重挑選了長度大於512的文本。
- 構造數據:在通用語料庫上使用LLM構造了一批(question, paragraph)和(sentence, paragraph)數據。
訓練方法
- 對比學習損失函數
- 帶有難負例的對比學習損失函數:分別基於bm25和vector構造了難負例。
- EWC(Elastic Weights Consolidation)[4]
- cosent loss[5]
- 迭代更新:每一種類型的數據一個迭代器,分別計算loss進行更新。
初始權重
stella-base-zh和stella-large-zh分別以piccolo-base-zh[6]和piccolo-large-zh作為基礎模型,512 - 1024的position embedding使用層次分解位置編碼[7]進行初始化。感謝商湯科技研究院開源的piccolo系列模型。
stella-v2改進
stella-v2在stella模型的基礎上,使用了更多的訓練數據,同時通過知識蒸餾等方法去除了前置的instruction(比如piccolo的查詢:
, 結果:
, e5的query:
和passage:
)。
🔧 技術細節
硬件
單卡A100 - 80GB
環境
torch1.13.*; transformers-trainer + deepspeed + gradient-checkpointing
學習率
1e - 6
batch_size
- base模型為1024,額外增加20%的難負例。
- large模型為768,額外增加20%的難負例。
數據量
- 第一版模型約100萬,其中用LLM構造的數據約有200K。LLM模型大小為13b。
- v2系列模型到了2000萬訓練數據。
📄 許可證
本項目採用MIT許可證。
📈 評測指標
C-MTEB leaderboard (Chinese)
模型名稱 | 模型大小 (GB) | 向量維度 | 序列長度 | 平均得分(35) | 分類得分(9) | 聚類得分(4) | 成對分類得分(2) | 重排得分(4) | 檢索得分(8) | STS得分(8) |
---|---|---|---|---|---|---|---|---|---|---|
stella-large-zh-v2 | 0.65 | 1024 | 1024 | 65.13 | 69.05 | 49.16 | 82.68 | 66.41 | 70.14 | 58.66 |
stella-base-zh-v2 | 0.2 | 768 | 1024 | 64.36 | 68.29 | 49.4 | 79.95 | 66.1 | 70.08 | 56.92 |
stella-large-zh | 0.65 | 1024 | 1024 | 64.54 | 67.62 | 48.65 | 78.72 | 65.98 | 71.02 | 58.3 |
stella-base-zh | 0.2 | 768 | 1024 | 64.16 | 67.77 | 48.7 | 76.09 | 66.95 | 71.07 | 56.54 |
MTEB leaderboard (English)
模型名稱 | 模型大小 (GB) | 向量維度 | 序列長度 | 平均得分(56) | 分類得分(12) | 聚類得分(11) | 成對分類得分(3) | 重排得分(4) | 檢索得分(15) | STS得分(10) | 摘要得分(1) |
---|---|---|---|---|---|---|---|---|---|---|---|
stella-base-en-v2 | 0.2 | 768 | 512 | 62.61 | 75.28 | 44.9 | 86.45 | 58.77 | 50.1 | 83.02 | 32.52 |
復現結果
C-MTEB:
import torch
import numpy as np
from typing import List
from mteb import MTEB
from sentence_transformers import SentenceTransformer
class FastTextEncoder():
def __init__(self, model_name):
self.model = SentenceTransformer(model_name).cuda().half().eval()
self.model.max_seq_length = 512
def encode(
self,
input_texts: List[str],
*args,
**kwargs
):
new_sens = list(set(input_texts))
new_sens.sort(key=lambda x: len(x), reverse=True)
vecs = self.model.encode(
new_sens, normalize_embeddings=True, convert_to_numpy=True, batch_size=256
).astype(np.float32)
sen2arrid = {sen: idx for idx, sen in enumerate(new_sens)}
vecs = vecs[[sen2arrid[sen] for sen in input_texts]]
torch.cuda.empty_cache()
return vecs
if __name__ == '__main__':
model_name = "infgrad/stella-base-zh-v2"
output_folder = "zh_mteb_results/stella-base-zh-v2"
task_names = [t.description["name"] for t in MTEB(task_langs=['zh', 'zh-CN']).tasks]
model = FastTextEncoder(model_name)
for task in task_names:
MTEB(tasks=[task], task_langs=['zh', 'zh-CN']).run(model, output_folder=output_folder)
MTEB: 你可以使用官方腳本復現結果。scripts/run_mteb_english.py
長文本評測
現有數據集在評估模型長文本編碼能力方面存在問題,如長度大於512的文本過少,且多數情況下檢索只需前512的文本內容。為解決此問題,蒐集整理了6份長文本測試集:
- CMRC2018,通用百科
- CAIL,法律閱讀理解
- DRCD,繁體百科,已轉簡體
- Military,軍工問答
- Squad,英文閱讀理解,已轉中文
- Multifieldqa_zh,清華的大模型長文本理解能力評測數據[9]
處理規則是選取答案在512長度之後的文本,短的測試數據會欠採樣,長短文本佔比約為1:2。除Military數據集外,其他5個測試數據的下載地址:https://drive.google.com/file/d/1WC6EWaCbVgz-vPMDFH4TwAMkLyh5WNcN/view?usp=sharing
評測指標為Recall@5,結果如下:
數據集 | piccolo-base-zh | piccolo-large-zh | bge-base-zh | bge-large-zh | stella-base-zh | stella-large-zh |
---|---|---|---|---|---|---|
CMRC2018 | 94.34 | 93.82 | 91.56 | 93.12 | 96.08 | 95.56 |
CAIL | 28.04 | 33.64 | 31.22 | 33.94 | 34.62 | 37.18 |
DRCD | 78.25 | 77.9 | 78.34 | 80.26 | 86.14 | 84.58 |
Military | 76.61 | 73.06 | 75.65 | 75.81 | 83.71 | 80.48 |
Squad | 91.21 | 86.61 | 87.87 | 90.38 | 93.31 | 91.21 |
Multifieldqa_zh | 81.41 | 83.92 | 83.92 | 83.42 | 79.9 | 80.4 |
Average | 74.98 | 74.83 | 74.76 | 76.15 | 78.96 | 78.24 |
📋 ToDoList
- 評測的穩定性:評測過程中Clustering任務與官方結果有±0.0x的小差距,因聚類代碼未設置random_seed,差距可忽略。
- 更高質量的長文本訓練和測試數據:訓練數據存在噪聲,測試數據問題類型單一,不符合真實分佈。
- OOD的性能:在非通用領域,包括stella在內的眾多模型效果不如BM25。
📖 參考資料
- https://www.scidb.cn/en/detail?dataSetId=c6a3fe684227415a9db8e21bac4a15ab
- https://github.com/wangyuxinwhy/uniem
- https://github.com/CLUEbenchmark/SimCLUE
- https://arxiv.org/abs/1612.00796
- https://kexue.fm/archives/8847
- https://huggingface.co/sensenova/piccolo-base-zh
- https://kexue.fm/archives/7947
- https://github.com/FlagOpen/FlagEmbedding
- https://github.com/THUDM/LongBench
🌟 新聞動態
- [2024-04-06] 開源puff系列模型,專門針對檢索和語義匹配任務,更多考慮泛化性和私有通用測試集效果,向量維度可變,支持中英雙語。
- [2024-02-27] 開源stella-mrl-large-zh-v3.5-1792d模型,支持向量可變維度。
- [2024-02-17] 開源stella v3系列、dialogue編碼模型和相關訓練數據。
- [2023-10-19] 開源stella-base-en-v2,使用簡單,無需任何前綴文本。
- [2023-10-12] 開源stella-base-zh-v2和stella-large-zh-v2,效果更好且使用簡單,無需任何前綴文本。
- [2023-09-11] 開源stella-base-zh和stella-large-zh
歡迎去本人主頁查看最新模型,並提出您的寶貴意見!
⚠️ 重要提示
因為長文本評測數據數量稀少,所以構造時也使用了train部分,如果自行評測,請注意模型的訓練數據以免數據洩露。







