🚀 GLiClass: シーケンス分類のための汎用的で軽量なモデル
これは、GLiNERの研究にインスパイアされた効率的なゼロショット分類器です。分類が1回の順伝播で行われるため、クロスエンコーダと同等の性能を示しながら、より計算効率が高いです。
このモデルは、トピック分類
、センチメント分析
、およびRAG
パイプラインの再ランキングに使用できます。
モデルは、論理的なタスクで訓練され、推論能力を誘発しました。LoRAアダプタを使用して、以前の知識を破壊することなくモデルを微調整しました。
🚀 クイックスタート
インストール
まず、GLiClassライブラリをインストールする必要があります。
pip install gliclass
pip install -U transformers>=4.48.0
モデルとパイプラインの初期化
次に、モデルとパイプラインを初期化する必要があります。
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v3.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v3.0", add_prefix_space=True)
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(result["label"], "=>", result["score"])
NLIタイプのタスクでの使用
NLIタイプのタスクで使用する場合は、前提をテキストとして、仮説をラベルとして表現することをおすすめします。複数の仮説を入力することもできますが、モデルは単一の入力仮説で最適に動作します。
text = "The cat slept on the windowsill all afternoon"
labels = ["The cat was awake and playing outside."]
results = pipeline(text, labels, threshold=0.0)[0]
print(results)
✨ 主な機能
- ゼロショット分類が可能で、事前学習なしで分類タスクを実行できます。
- クロスエンコーダと同等の性能を維持しながら、計算効率が高いです。
- トピック分類、センチメント分析、RAGパイプラインの再ランキングなど、さまざまなタスクに使用できます。
📦 インストール
pip install gliclass
pip install -U transformers>=4.48.0
💻 使用例
基本的な使用法
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-modern-base-v3.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-modern-base-v3.0", add_prefix_space=True)
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(result["label"], "=>", result["score"])
高度な使用法
text = "The cat slept on the windowsill all afternoon"
labels = ["The cat was awake and playing outside."]
results = pipeline(text, labels, threshold=0.0)[0]
print(results)
📚 ドキュメント
LoRAパラメータ
|
gliclass‑modern‑base‑v3.0 |
gliclass‑modern‑large‑v3.0 |
gliclass‑base‑v3.0 |
gliclass‑large‑v3.0 |
LoRa r |
512 |
768 |
384 |
384 |
LoRa α |
1024 |
1536 |
768 |
768 |
focal loss α |
0.7 |
0.7 |
0.7 |
0.7 |
Target modules |
"Wqkv", "Wo", "Wi", "linear_1", "linear_2" |
"Wqkv", "Wo", "Wi", "linear_1", "linear_2" |
"query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", mlp.0", "mlp.2", "mlp.4" |
"query_proj", "key_proj", "value_proj", "dense", "linear_1", "linear_2", mlp.0", "mlp.2", "mlp.4" |
GLiClass-V3モデル
ベンチマーク
以下は、いくつかのテキスト分類データセットでのF1スコアです。すべてのテスト対象モデルは、これらのデータセットで微調整されておらず、ゼロショット設定でテストされています。
GLiClass-V3
データセット |
gliclass‑large‑v3.0 |
gliclass‑base‑v3.0 |
gliclass‑modern‑large‑v3.0 |
gliclass‑modern‑base‑v3.0 |
gliclass‑edge‑v3.0 |
CR |
0.9398 |
0.9127 |
0.8952 |
0.8902 |
0.8215 |
sst2 |
0.9192 |
0.8959 |
0.9330 |
0.8959 |
0.8199 |
sst5 |
0.4606 |
0.3376 |
0.4619 |
0.2756 |
0.2823 |
20_news_ groups |
0.5958 |
0.4759 |
0.3905 |
0.3433 |
0.2217 |
spam |
0.7584 |
0.6760 |
0.5813 |
0.6398 |
0.5623 |
financial_ phrasebank |
0.9000 |
0.8971 |
0.5929 |
0.4200 |
0.5004 |
imdb |
0.9366 |
0.9251 |
0.9402 |
0.9158 |
0.8485 |
ag_news |
0.7181 |
0.7279 |
0.7269 |
0.6663 |
0.6645 |
emotion |
0.4506 |
0.4447 |
0.4517 |
0.4254 |
0.3851 |
cap_sotu |
0.4589 |
0.4614 |
0.4072 |
0.3625 |
0.2583 |
rotten_ tomatoes |
0.8411 |
0.7943 |
0.7664 |
0.7070 |
0.7024 |
massive |
0.5649 |
0.5040 |
0.3905 |
0.3442 |
0.2414 |
banking |
0.5574 |
0.4698 |
0.3683 |
0.3561 |
0.0272 |
AVERAGE |
0.7001 |
0.6556 |
0.6082 |
0.5571 |
0.4873 |
以前のGLiClassモデル
データセット |
gliclass‑large‑v1.0‑lw |
gliclass‑base‑v1.0‑lw |
gliclass‑modern‑large‑v2.0 |
gliclass‑modern‑base‑v2.0 |
CR |
0.9226 |
0.9097 |
0.9154 |
0.8977 |
sst2 |
0.9247 |
0.8987 |
0.9308 |
0.8524 |
sst5 |
0.2891 |
0.3779 |
0.2152 |
0.2346 |
20_news_ groups |
0.4083 |
0.3953 |
0.3813 |
0.3857 |
spam |
0.3642 |
0.5126 |
0.6603 |
0.4608 |
financial_ phrasebank |
0.9044 |
0.8880 |
0.3152 |
0.3465 |
imdb |
0.9429 |
0.9351 |
0.9449 |
0.9188 |
ag_news |
0.7559 |
0.6985 |
0.6999 |
0.6836 |
emotion |
0.3951 |
0.3516 |
0.4341 |
0.3926 |
cap_sotu |
0.4749 |
0.4643 |
0.4095 |
0.3588 |
rotten_ tomatoes |
0.8807 |
0.8429 |
0.7386 |
0.6066 |
massive |
0.5606 |
0.4635 |
0.2394 |
0.3458 |
banking |
0.3317 |
0.4396 |
0.1355 |
0.2907 |
AVERAGE |
0.6273 |
0.6291 |
0.5400 |
0.5211 |
クロスエンコーダ
データセット |
deberta‑v3‑large‑zeroshot‑v2.0 |
deberta‑v3‑base‑zeroshot‑v2.0 |
roberta‑large‑zeroshot‑v2.0‑c |
comprehend_it‑base |
CR |
0.9134 |
0.9051 |
0.9141 |
0.8936 |
sst2 |
0.9272 |
0.9176 |
0.8573 |
0.9006 |
sst5 |
0.3861 |
0.3848 |
0.4159 |
0.4140 |
enron_ spam |
0.5970 |
0.4640 |
0.5040 |
0.3637 |
financial_ phrasebank |
0.5820 |
0.6690 |
0.4550 |
0.4695 |
imdb |
0.9180 |
0.8990 |
0.9040 |
0.4644 |
ag_news |
0.7710 |
0.7420 |
0.7450 |
0.6016 |
emotion |
0.4840 |
0.4950 |
0.4860 |
0.4165 |
cap_sotu |
0.5020 |
0.4770 |
0.5230 |
0.3823 |
rotten_ tomatoes |
0.8680 |
0.8600 |
0.8410 |
0.4728 |
massive |
0.5180 |
0.5200 |
0.5200 |
0.3314 |
banking77 |
0.5670 |
0.4460 |
0.2900 |
0.4972 |
AVERAGE |
0.6695 |
0.6483 |
0.6213 |
0.5173 |
推論速度
各モデルは、テキスト中に64、256、および512トークンを含むサンプルと、1、2、4、8、16、32、64、および128のラベルで、a6000 GPU上でテストされました。その後、テキスト長にわたってスコアが平均化されました。
📄 ライセンス
このプロジェクトは、Apache-2.0ライセンスの下でライセンスされています。