# 🚀 GLiClass: Generalist Lightweight Model for Sequence Classification
GLiClass is an efficient zero-shot classifier inspired by the GLiNER work. It performs classification in a single forward pass, achieving performance comparable to cross-encoders while being far more computationally efficient. It can be used for topic classification and sentiment analysis, and as a reranker in RAG pipelines.
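As a sketch of the reranking use case: given per-document relevance scores produced by the classifier, retrieved documents can be reordered before they reach the generator. The `rerank` helper and the example scores below are illustrative, not part of the gliclass API.

```python
def rerank(documents, scores, top_k=3):
    """Order retrieved documents by classifier relevance score, highest first."""
    ranked = sorted(zip(documents, scores), key=lambda pair: pair[1], reverse=True)
    return [doc for doc, _ in ranked[:top_k]]

# Hypothetical scores, e.g. for the label "relevant to the query".
docs = ["doc about travel", "doc about sports", "doc about dreams"]
scores = [0.91, 0.12, 0.55]
print(rerank(docs, scores, top_k=2))
```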
## 🚀 Quick Start

### Installation

First, install the GLiClass library:

```shell
pip install gliclass
pip install -U "transformers>=4.48.0"
```
### Initializing the Model and Pipeline

```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v3.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v3.0", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]  # predictions for the first (and only) input text

for result in results:
    print(result["label"], "=>", result["score"])
```
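With `classification_type='multi-label'`, every label scoring above the threshold is returned. When a single best label is needed, a small helper can pick it from the returned list of label/score dicts (the helper below is a sketch, not part of the library; only the result format is taken from the loop above):

```python
def top_label(results, default=None):
    """Return the (label, score) pair with the highest score, or `default` if empty."""
    if not results:
        return default
    best = max(results, key=lambda r: r["score"])
    return best["label"], best["score"]

# Illustrative pipeline output (scores are made up for the example).
results = [
    {"label": "travel", "score": 0.97},
    {"label": "dreams", "score": 0.85},
]
print(top_label(results))
```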
### Using GLiClass for NLI Tasks

To use the model for NLI-style tasks, pass the premise as the text and the hypothesis as a label. You can supply several hypotheses, but the model works best with a single hypothesis per input.

```python
text = "The cat slept on the windowsill all afternoon"
labels = ["The cat was awake and playing outside."]
results = pipeline(text, labels, threshold=0.0)[0]
print(results)
```
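The returned score can be read as an entailment probability for the hypothesis. A minimal sketch for turning it into a discrete verdict (the 0.5 decision boundary is an assumption, not something the library defines; tune it on a validation set):

```python
def nli_verdict(score, boundary=0.5):
    """Map an entailment score to a coarse NLI decision.

    The 0.5 boundary is a heuristic assumption, not a library default.
    """
    return "entailment" if score >= boundary else "not_entailment"

# The premise above contradicts the hypothesis, so a low score is expected.
print(nli_verdict(0.03))  # -> not_entailment
print(nli_verdict(0.92))  # -> entailment
```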
## ✨ Key Features

- **Efficient zero-shot classification**: inspired by the GLiNER work; classifies in a single forward pass with high computational efficiency.
- **Multi-task use**: topic classification, sentiment analysis, and reranking in RAG pipelines.
- **Logical reasoning**: trained on logic tasks to induce reasoning capabilities.
- **LoRA fine-tuning**: fine-tune with LoRA adapters while preserving prior knowledge.
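LoRA keeps the pretrained weight matrix W frozen and learns a low-rank update ΔW = (α/r)·B·A, where B is d_out×r and A is r×d_in. In every configuration listed in the LoRA parameters table, α = 2r, so the merge scale α/r is 2. A pure-Python sketch of the weight merge with toy dimensions (this illustrates the math only, not the library's implementation):

```python
def matmul(A, B):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]

def lora_update(W, A, B, r, alpha):
    """Apply the LoRA merge W' = W + (alpha / r) * B @ A."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]

# Toy example: 2x2 weight, rank r = 1, alpha = 2 (scale 2, as in the v3.0 configs).
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [0.0]]   # d_out x r
A = [[0.5, 0.5]]     # r x d_in
print(lora_update(W, A, B, r=1, alpha=2))
```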
## 📚 Documentation

### LoRA Parameters

| | gliclass-modern-base-v3.0 | gliclass-modern-large-v3.0 | gliclass-base-v3.0 | gliclass-large-v3.0 |
|---|---|---|---|---|
| LoRA r | 512 | 768 | 384 | 384 |
| LoRA α | 1024 | 1536 | 768 | 768 |
| focal loss α | 0.7 | 0.7 | 0.7 | 0.7 |
| Target modules | "Wqkv", "Wo", "Wi", "linear_1", "linear_2" | "Wqkv", "Wo", "Wi", "linear_1", "linear_2" | "query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", "mlp.0", "mlp.2", "mlp.4" | "query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", "mlp.0", "mlp.2", "mlp.4" |
### Benchmarks

Below are F1 scores on several text classification datasets. None of the tested models were fine-tuned on these datasets; all were evaluated in a zero-shot setting.

#### GLiClass-V3

| Dataset | gliclass-large-v3.0 | gliclass-base-v3.0 | gliclass-modern-large-v3.0 | gliclass-modern-base-v3.0 | gliclass-edge-v3.0 |
|---|---|---|---|---|---|
| CR | 0.9398 | 0.9127 | 0.8952 | 0.8902 | 0.8215 |
| sst2 | 0.9192 | 0.8959 | 0.9330 | 0.8959 | 0.8199 |
| sst5 | 0.4606 | 0.3376 | 0.4619 | 0.2756 | 0.2823 |
| 20_news_groups | 0.5958 | 0.4759 | 0.3905 | 0.3433 | 0.2217 |
| spam | 0.7584 | 0.6760 | 0.5813 | 0.6398 | 0.5623 |
| financial_phrasebank | 0.9000 | 0.8971 | 0.5929 | 0.4200 | 0.5004 |
| imdb | 0.9366 | 0.9251 | 0.9402 | 0.9158 | 0.8485 |
| ag_news | 0.7181 | 0.7279 | 0.7269 | 0.6663 | 0.6645 |
| emotion | 0.4506 | 0.4447 | 0.4517 | 0.4254 | 0.3851 |
| cap_sotu | 0.4589 | 0.4614 | 0.4072 | 0.3625 | 0.2583 |
| rotten_tomatoes | 0.8411 | 0.7943 | 0.7664 | 0.7070 | 0.7024 |
| massive | 0.5649 | 0.5040 | 0.3905 | 0.3442 | 0.2414 |
| banking | 0.5574 | 0.4698 | 0.3683 | 0.3561 | 0.0272 |
| **Average** | **0.7001** | **0.6556** | **0.6082** | **0.5571** | **0.4873** |
#### Previous GLiClass Models

| Dataset | gliclass-large-v1.0-lw | gliclass-base-v1.0-lw | gliclass-modern-large-v2.0 | gliclass-modern-base-v2.0 |
|---|---|---|---|---|
| CR | 0.9226 | 0.9097 | 0.9154 | 0.8977 |
| sst2 | 0.9247 | 0.8987 | 0.9308 | 0.8524 |
| sst5 | 0.2891 | 0.3779 | 0.2152 | 0.2346 |
| 20_news_groups | 0.4083 | 0.3953 | 0.3813 | 0.3857 |
| spam | 0.3642 | 0.5126 | 0.6603 | 0.4608 |
| financial_phrasebank | 0.9044 | 0.8880 | 0.3152 | 0.3465 |
| imdb | 0.9429 | 0.9351 | 0.9449 | 0.9188 |
| ag_news | 0.7559 | 0.6985 | 0.6999 | 0.6836 |
| emotion | 0.3951 | 0.3516 | 0.4341 | 0.3926 |
| cap_sotu | 0.4749 | 0.4643 | 0.4095 | 0.3588 |
| rotten_tomatoes | 0.8807 | 0.8429 | 0.7386 | 0.6066 |
| massive | 0.5606 | 0.4635 | 0.2394 | 0.3458 |
| banking | 0.3317 | 0.4396 | 0.1355 | 0.2907 |
| **Average** | **0.6273** | **0.6291** | **0.5400** | **0.5211** |
#### Cross-Encoders

| Dataset | deberta-v3-large-zeroshot-v2.0 | deberta-v3-base-zeroshot-v2.0 | roberta-large-zeroshot-v2.0-c | comprehend_it-base |
|---|---|---|---|---|
| CR | 0.9134 | 0.9051 | 0.9141 | 0.8936 |
| sst2 | 0.9272 | 0.9176 | 0.8573 | 0.9006 |
| sst5 | 0.3861 | 0.3848 | 0.4159 | 0.4140 |
| enron_spam | 0.5970 | 0.4640 | 0.5040 | 0.3637 |
| financial_phrasebank | 0.5820 | 0.6690 | 0.4550 | 0.4695 |
| imdb | 0.9180 | 0.8990 | 0.9040 | 0.4644 |
| ag_news | 0.7710 | 0.7420 | 0.7450 | 0.6016 |
| emotion | 0.4840 | 0.4950 | 0.4860 | 0.4165 |
| cap_sotu | 0.5020 | 0.4770 | 0.5230 | 0.3823 |
| rotten_tomatoes | 0.8680 | 0.8600 | 0.8410 | 0.4728 |
| massive | 0.5180 | 0.5200 | 0.5200 | 0.3314 |
| banking77 | 0.5670 | 0.4460 | 0.2900 | 0.4972 |
| **Average** | **0.6695** | **0.6483** | **0.6213** | **0.5173** |
### Inference Speed

Each model was benchmarked on inputs of 64, 256, and 512 tokens, with 1, 2, 4, 8, 16, 32, 64, and 128 labels; scores were then averaged across the text lengths.
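The measurement protocol above can be sketched as a small harness that times a classification call per label count and averages over text lengths. The `fake_classify` stand-in below is a placeholder so the harness runs without a model or GPU; it is not the gliclass API.

```python
import time
from statistics import mean

def benchmark(classify, text_lengths=(64, 256, 512),
              label_counts=(1, 2, 4, 8, 16, 32, 64, 128)):
    """Time classify(n_tokens, n_labels) per label count, averaged over text lengths."""
    averaged = {}
    for n_labels in label_counts:
        times = []
        for n_tokens in text_lengths:
            start = time.perf_counter()
            classify(n_tokens, n_labels)
            times.append(time.perf_counter() - start)
        averaged[n_labels] = mean(times)
    return averaged

# Stand-in workload: a real harness would call the ZeroShotClassificationPipeline here.
def fake_classify(n_tokens, n_labels):
    return [0.0] * n_labels

timings = benchmark(fake_classify)
print(sorted(timings))
```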
## 📄 License

This project is licensed under the Apache-2.0 license.

## Model Information

| Property | Details |
|---|---|
| Model type | Generalist lightweight sequence classification model |
| Training data | BioMike/formal-logic-reasoning-gliclass-2k, knowledgator/gliclass-v3-logic-dataset, tau/commonsense_qa |
| Metric | F1 |
| Tags | text classification, nli, sentiment analysis |
| Pipeline tag | text-classification |
