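🚀 GLiClass: Generalist and Lightweight Model for Sequence Classification
GLiClass is an efficient zero-shot classifier inspired by the GLiNER work. It completes classification in a single forward pass, delivering performance comparable to cross-encoders at a lower computational cost. It can be used for topic classification, sentiment analysis, and as a reranker in RAG pipelines.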
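🚀 Quick Start
Installation
First, install the GLiClass library (the version specifier is quoted so the shell does not treat `>` as a redirect):
pip install gliclass
pip install -U "transformers>=4.48.0"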
Initialize the model and pipeline
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

# Load the pretrained model and its tokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v3.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v3.0", add_prefix_space=True)
# 'multi-label' lets several labels pass the threshold for one text
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')

text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
# The pipeline returns one result list per input text; [0] selects the list for `text`
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
    print(result["label"], "=>", result["score"])
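The loop above shows that each result is a dictionary with `label` and `score` keys. As a minimal sketch of post-processing that output, the hypothetical helper `top_labels` (not part of the gliclass API) sorts such a list and keeps the best entries. The sample scores are made up for illustration.

```python
from typing import Dict, List, Union

Result = Dict[str, Union[str, float]]

def top_labels(results: List[Result], k: int = 3, min_score: float = 0.0) -> List[str]:
    """Return up to k label names with score >= min_score, highest score first."""
    kept = [r for r in results if r["score"] >= min_score]
    kept.sort(key=lambda r: r["score"], reverse=True)
    return [r["label"] for r in kept[:k]]

# Illustrative values only; real scores come from pipeline(text, labels, ...)[0].
sample_results = [
    {"label": "travel", "score": 0.91},
    {"label": "sport", "score": 0.12},
    {"label": "dreams", "score": 0.78},
]
print(top_labels(sample_results, k=2, min_score=0.5))
```

Sorting explicitly avoids relying on any particular ordering of the pipeline output.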
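Using the model for NLI tasks
For NLI-style tasks, pass the premise as the text and the hypothesis as a label. You can supply several hypotheses at once, but the model works best with a single hypothesis per input.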
text = "The cat slept on the windowsill all afternoon"
labels = ["The cat was awake and playing outside."]
results = pipeline(text, labels, threshold=0.0)[0]
print(results)
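Because the model is reported to work best with one hypothesis per input, one option is to score each hypothesis in its own call. The sketch below assumes a callable with the same call shape as the pipeline above (`classify(text, labels, threshold=...)` returning one result list per text). `score_hypotheses` and `fake_classify` are hypothetical names, and the stand-in scorer exists only so the example runs without downloading a model.

```python
from typing import Callable, Dict, List

def score_hypotheses(classify: Callable, premise: str, hypotheses: List[str]) -> Dict[str, float]:
    """Score each hypothesis against the premise in a separate call."""
    scores: Dict[str, float] = {}
    for hypothesis in hypotheses:
        # threshold=0.0 keeps the score even when it is low
        results = classify(premise, [hypothesis], threshold=0.0)[0]
        scores[hypothesis] = results[0]["score"] if results else 0.0
    return scores

def fake_classify(text, labels, threshold=0.5):
    """Stand-in for the real pipeline: scores a label by word overlap with the text."""
    words = set(text.lower().split())
    out = []
    for label in labels:
        label_words = set(label.lower().rstrip(".").split())
        score = len(words & label_words) / max(len(label_words), 1)
        if score >= threshold:
            out.append({"label": label, "score": score})
    return [out]

premise = "The cat slept on the windowsill all afternoon"
hypotheses = ["The cat was asleep.", "The cat was awake and playing outside."]
scores = score_hypotheses(fake_classify, premise, hypotheses)
print(scores)
```

To use it with the real model, pass the `pipeline` object in place of `fake_classify`.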
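✨ Key Features
- Efficient zero-shot classification: inspired by the GLiNER work, the model classifies in a single forward pass and is computationally efficient.
- Multiple applications: topic classification, sentiment analysis, and reranking in RAG pipelines.
- Logical reasoning: the model was trained on logical tasks to induce reasoning capabilities.
- LoRA fine-tuning: the model was fine-tuned with LoRA adapters to preserve previously acquired knowledge.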
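The reranking use case listed above can be sketched independently of the model: score every retrieved passage against the query, then keep the highest-scoring ones. The function names below are hypothetical, and `overlap_score` is a toy scorer used only to make the example runnable. How a GLiClass score would be obtained for a query and passage pair (for example, passage as text and query as label) is an assumption, not something the source specifies.

```python
from typing import Callable, List, Tuple

def rerank(
    query: str,
    passages: List[str],
    score_fn: Callable[[str, str], float],
    top_k: int = 3,
) -> List[Tuple[str, float]]:
    """Order passages by score_fn(query, passage), best first, and keep top_k."""
    scored = [(p, score_fn(query, p)) for p in passages]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_k]

def overlap_score(query: str, passage: str) -> float:
    """Toy relevance score: fraction of query words found in the passage."""
    q = set(query.lower().split())
    return len(q & set(passage.lower().split())) / max(len(q), 1)

passages = [
    "GLiClass classifies text in a single forward pass.",
    "Bananas are rich in potassium.",
    "Zero-shot models classify text without task-specific training.",
]
ranked = rerank("classify text zero-shot", passages, overlap_score, top_k=2)
for passage, score in ranked:
    print(f"{score:.2f}  {passage}")
```

Keeping the scorer behind a plain callable makes it easy to swap the toy function for a model-backed one later.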
📚 Documentation
LoRA parameters
| | gliclass-modern-base-v3.0 | gliclass-modern-large-v3.0 | gliclass-base-v3.0 | gliclass-large-v3.0 |
|---|---|---|---|---|
| LoRA r | 512 | 768 | 384 | 384 |
| LoRA α | 1024 | 1536 | 768 | 768 |
| Focal loss α | 0.7 | 0.7 | 0.7 | 0.7 |
| Target modules | "Wqkv", "Wo", "Wi", "linear_1", "linear_2" | "Wqkv", "Wo", "Wi", "linear_1", "linear_2" | "query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", "mlp.0", "mlp.2", "mlp.4" | "query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", "mlp.0", "mlp.2", "mlp.4" |
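GLiClass-V3 model information
Benchmarks
The tables below report F1 scores on several text classification datasets. None of the tested models were fine-tuned on these datasets; all were evaluated in a zero-shot setting.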
GLiClass-V3
| Dataset | gliclass-large-v3.0 | gliclass-base-v3.0 | gliclass-modern-large-v3.0 | gliclass-modern-base-v3.0 | gliclass-edge-v3.0 |
|---|---|---|---|---|---|
| CR | 0.9398 | 0.9127 | 0.8952 | 0.8902 | 0.8215 |
| sst2 | 0.9192 | 0.8959 | 0.9330 | 0.8959 | 0.8199 |
| sst5 | 0.4606 | 0.3376 | 0.4619 | 0.2756 | 0.2823 |
| 20_news_groups | 0.5958 | 0.4759 | 0.3905 | 0.3433 | 0.2217 |
| spam | 0.7584 | 0.6760 | 0.5813 | 0.6398 | 0.5623 |
| financial_phrasebank | 0.9000 | 0.8971 | 0.5929 | 0.4200 | 0.5004 |
| imdb | 0.9366 | 0.9251 | 0.9402 | 0.9158 | 0.8485 |
| ag_news | 0.7181 | 0.7279 | 0.7269 | 0.6663 | 0.6645 |
| emotion | 0.4506 | 0.4447 | 0.4517 | 0.4254 | 0.3851 |
| cap_sotu | 0.4589 | 0.4614 | 0.4072 | 0.3625 | 0.2583 |
| rotten_tomatoes | 0.8411 | 0.7943 | 0.7664 | 0.7070 | 0.7024 |
| massive | 0.5649 | 0.5040 | 0.3905 | 0.3442 | 0.2414 |
| banking | 0.5574 | 0.4698 | 0.3683 | 0.3561 | 0.0272 |
| Average | 0.7001 | 0.6556 | 0.6082 | 0.5571 | 0.4873 |
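The final row of the table appears to be the unweighted mean of the 13 per-dataset scores. The check below reproduces the reported average for gliclass-large-v3.0 from its column; the variable names are chosen for this example only.

```python
from statistics import mean

# F1 scores for gliclass-large-v3.0, copied from the table above
large_v3_f1 = {
    "CR": 0.9398,
    "sst2": 0.9192,
    "sst5": 0.4606,
    "20_news_groups": 0.5958,
    "spam": 0.7584,
    "financial_phrasebank": 0.9000,
    "imdb": 0.9366,
    "ag_news": 0.7181,
    "emotion": 0.4506,
    "cap_sotu": 0.4589,
    "rotten_tomatoes": 0.8411,
    "massive": 0.5649,
    "banking": 0.5574,
}

# Unweighted (macro) average across datasets, rounded like the table
average = round(mean(large_v3_f1.values()), 4)
print(average)  # 0.7001
```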
Previous GLiClass models
| Dataset | gliclass-large-v1.0-lw | gliclass-base-v1.0-lw | gliclass-modern-large-v2.0 | gliclass-modern-base-v2.0 |
|---|---|---|---|---|
| CR | 0.9226 | 0.9097 | 0.9154 | 0.8977 |
| sst2 | 0.9247 | 0.8987 | 0.9308 | 0.8524 |
| sst5 | 0.2891 | 0.3779 | 0.2152 | 0.2346 |
| 20_news_groups | 0.4083 | 0.3953 | 0.3813 | 0.3857 |
| spam | 0.3642 | 0.5126 | 0.6603 | 0.4608 |
| financial_phrasebank | 0.9044 | 0.8880 | 0.3152 | 0.3465 |
| imdb | 0.9429 | 0.9351 | 0.9449 | 0.9188 |
| ag_news | 0.7559 | 0.6985 | 0.6999 | 0.6836 |
| emotion | 0.3951 | 0.3516 | 0.4341 | 0.3926 |
| cap_sotu | 0.4749 | 0.4643 | 0.4095 | 0.3588 |
| rotten_tomatoes | 0.8807 | 0.8429 | 0.7386 | 0.6066 |
| massive | 0.5606 | 0.4635 | 0.2394 | 0.3458 |
| banking | 0.3317 | 0.4396 | 0.1355 | 0.2907 |
| Average | 0.6273 | 0.6291 | 0.5400 | 0.5211 |
Cross-encoders
| Dataset | deberta-v3-large-zeroshot-v2.0 | deberta-v3-base-zeroshot-v2.0 | roberta-large-zeroshot-v2.0-c | comprehend_it-base |
|---|---|---|---|---|
| CR | 0.9134 | 0.9051 | 0.9141 | 0.8936 |
| sst2 | 0.9272 | 0.9176 | 0.8573 | 0.9006 |
| sst5 | 0.3861 | 0.3848 | 0.4159 | 0.4140 |
| enron_spam | 0.5970 | 0.4640 | 0.5040 | 0.3637 |
| financial_phrasebank | 0.5820 | 0.6690 | 0.4550 | 0.4695 |
| imdb | 0.9180 | 0.8990 | 0.9040 | 0.4644 |
| ag_news | 0.7710 | 0.7420 | 0.7450 | 0.6016 |
| emotion | 0.4840 | 0.4950 | 0.4860 | 0.4165 |
| cap_sotu | 0.5020 | 0.4770 | 0.5230 | 0.3823 |
| rotten_tomatoes | 0.8680 | 0.8600 | 0.8410 | 0.4728 |
| massive | 0.5180 | 0.5200 | 0.5200 | 0.3314 |
| banking77 | 0.5670 | 0.4460 | 0.2900 | 0.4972 |
| Average | 0.6695 | 0.6483 | 0.6213 | 0.5173 |
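Inference speed
Each model was tested on examples with text lengths of 64, 256, and 512 tokens and with 1, 2, 4, 8, 16, 32, 64, and 128 labels; the scores were then averaged across the text lengths.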
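A measurement over that grid can be sketched as follows. This is not the authors' benchmark script: the metric (examples per second), the number of repeats, the use of whitespace-separated words as a stand-in for tokens, and the names `benchmark` and `dummy_classify` are all assumptions made for illustration.

```python
import time
from statistics import mean
from typing import Callable, Dict, List

TEXT_LENGTHS = [64, 256, 512]                   # text lengths from the description
LABEL_COUNTS = [1, 2, 4, 8, 16, 32, 64, 128]    # label counts from the description

def benchmark(classify: Callable[[str, List[str]], object], repeats: int = 3) -> Dict[int, float]:
    """Return examples/second per label count, averaged over text lengths."""
    throughput: Dict[int, float] = {}
    for n_labels in LABEL_COUNTS:
        labels = [f"label_{i}" for i in range(n_labels)]
        per_length = []
        for n_tokens in TEXT_LENGTHS:
            # Words approximate tokens here; a real run would size inputs with the tokenizer.
            text = " ".join(["word"] * n_tokens)
            start = time.perf_counter()
            for _ in range(repeats):
                classify(text, labels)
            elapsed = time.perf_counter() - start
            per_length.append(repeats / max(elapsed, 1e-9))
        throughput[n_labels] = mean(per_length)
    return throughput

def dummy_classify(text: str, labels: List[str]) -> List[float]:
    """Placeholder workload so the sketch runs without a model."""
    return [float(len(text) % (i + 2)) for i, _ in enumerate(labels)]

speed = benchmark(dummy_classify)
for n_labels, examples_per_second in speed.items():
    print(f"{n_labels:>3} labels: {examples_per_second:,.0f} examples/s")
```

Passing a function that wraps the real pipeline in place of `dummy_classify` would time the model itself; GPU runs would also need warm-up iterations and device synchronization, which this sketch omits.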
📄 License
This project is licensed under the Apache-2.0 license.
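Model information
| Property | Details |
|---|---|
| Model type | Generalist lightweight sequence classification model |
| Training data | BioMike/formal-logic-reasoning-gliclass-2k, knowledgator/gliclass-v3-logic-dataset, tau/commonsense_qa |
| Evaluation metric | F1 |
| Tags | text classification, nli, sentiment analysis |
| Pipeline tag | text-classification |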
