# 🚀 ⭐ GLiClass: Generalist Lightweight Model for Sequence Classification
GLiClass is an efficient zero-shot classifier inspired by the GLiNER work. It matches cross-encoders in classification performance while being far more compute-efficient, since it classifies in a single forward pass.

The model can be used for topic classification and sentiment analysis, and it can also serve as a reranker in RAG pipelines. It was trained on synthetic data and commercially permissive licensed data, so it can be used in commercial settings. Its backbone is mdeberta-v3-base, which provides multilingual understanding and makes it well suited to text tasks in different languages.
## 🚀 Quick Start

### Install the GLiClass library

First, install the GLiClass library (note that `>=` must be quoted so the shell does not treat it as a redirect):

```bash
pip install gliclass
pip install -U "transformers>=4.48.0"
```
### Initialize the model and pipeline

Here are usage examples in several languages:
**English**
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-x-base")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-x-base", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])
```
**Spanish**
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-x-base")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-x-base", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "¡Un día veré el mundo!"
labels = ["viajes", "sueños", "deportes", "ciencia", "política"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])
```
**Italian**
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-x-base")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-x-base", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "Un giorno vedrò il mondo!"
labels = ["viaggi", "sogni", "sport", "scienza", "politica"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])
```
**French**
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-x-base")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-x-base", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "Un jour, je verrai le monde!"
labels = ["voyage", "rêves", "sport", "science", "politique"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])
```
**German**
```python
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

model = GLiClassModel.from_pretrained("knowledgator/gliclass-x-base")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-x-base", add_prefix_space=True)

pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "Eines Tages werde ich die Welt sehen!"
labels = ["Reisen", "Träume", "Sport", "Wissenschaft", "Politik"]
results = pipeline(text, labels, threshold=0.5)[0]

for result in results:
    print(result["label"], "=>", result["score"])
```
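As mentioned above, GLiClass can also act as a reranker in a RAG pipeline. A minimal sketch, assuming the `pipeline` object initialized in the examples above: treat the user query as a single candidate label and score each retrieved document against it, then sort by score. The `rerank` helper below is a hypothetical illustration, not part of the gliclass API.

```python
# Hypothetical helper (not part of the gliclass API): rerank retrieved
# documents by scoring each one against the query, with the query used
# as the single candidate label for the zero-shot pipeline.
def rerank(pipeline, query, documents, threshold=0.0):
    scored = []
    for doc in documents:
        # pipeline(text, labels, ...) returns one result list per input
        # text; [0] selects the results for this single document.
        results = pipeline(doc, [query], threshold=threshold)[0]
        score = results[0]["score"] if results else 0.0
        scored.append((doc, score))
    # Most relevant (highest-scoring) documents first.
    return sorted(scored, key=lambda pair: pair[1], reverse=True)
```

The returned list can then be truncated to the top-k documents before they are passed to the generator.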
## 📊 Benchmarks

Below are the model's F1 scores on several text classification datasets. None of the evaluated models were fine-tuned on these datasets; all were evaluated in a zero-shot setting.
### Multilingual benchmarks

| Dataset | gliclass-x-base | gliclass-base-v3.0 | gliclass-large-v3.0 |
|---|---|---|---|
| FredZhang7/toxi-text-3M | 0.5972 | 0.5072 | 0.6118 |
| SetFit/xglue_nc | 0.5014 | 0.5348 | 0.5378 |
| Davlan/sib200_14classes | 0.4663 | 0.2867 | 0.3173 |
| uhhlt/GermEval2017 | 0.3999 | 0.4010 | 0.4299 |
| dolfsai/toxic_es | 0.1250 | 0.1399 | 0.1412 |
| **Average** | **0.41796** | **0.37392** | **0.4076** |
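The per-model averages in the table above are plain arithmetic means of the five per-dataset F1 scores, which can be checked directly:

```python
# Verify the multilingual-benchmark "Average" row as the arithmetic
# mean of the five per-dataset F1 scores for each model.
scores = {
    "gliclass-x-base": [0.5972, 0.5014, 0.4663, 0.3999, 0.1250],
    "gliclass-base-v3.0": [0.5072, 0.5348, 0.2867, 0.4010, 0.1399],
    "gliclass-large-v3.0": [0.6118, 0.5378, 0.3173, 0.4299, 0.1412],
}
averages = {name: round(sum(f1s) / len(f1s), 5) for name, f1s in scores.items()}
print(averages)  # 0.41796, 0.37392, 0.4076 — matching the Average row
```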
### General benchmarks

| Dataset | gliclass-x-base | gliclass-base-v3.0 | gliclass-large-v3.0 |
|---|---|---|---|
| SetFit/CR | 0.8630 | 0.9127 | 0.9398 |
| SetFit/sst2 | 0.8554 | 0.8959 | 0.9192 |
| SetFit/sst5 | 0.3287 | 0.3376 | 0.4606 |
| AmazonScience/massive | 0.2611 | 0.5040 | 0.5649 |
| stanfordnlp/imdb | 0.8840 | 0.9251 | 0.9366 |
| SetFit/20_newsgroups | 0.4116 | 0.4759 | 0.5958 |
| SetFit/enron_spam | 0.5929 | 0.6760 | 0.7584 |
| PolyAI/banking77 | 0.3098 | 0.4698 | 0.5574 |
| takala/financial_phrasebank | 0.7851 | 0.8971 | 0.9000 |
| ag_news | 0.6815 | 0.7279 | 0.7181 |
| dair-ai/emotion | 0.3667 | 0.4447 | 0.4506 |
| MoritzLaurer/cap_sotu | 0.3935 | 0.4614 | 0.4589 |
| cornell/rotten_tomatoes | 0.7252 | 0.7943 | 0.8411 |
| **Average** | **0.5737** | **0.6556** | **0.7001** |
## 📄 License

This project is licensed under the Apache-2.0 license.