🚀 ⭐ GLiClass: シーケンス分類のための汎用的で軽量なモデル
これは、GLiNER の研究に触発された効率的なゼロショット分類器です。分類が単一の順伝播で行われるため、クロスエンコーダと同じ性能を示しながら、より計算効率が高いです。
このモデルは、トピック分類
、感情分析
、および RAG
パイプラインのリランカーとして使用できます。
このモデルは合成データで学習されており、商用アプリケーションで使用することができます。
🚀 クイックスタート
使い方:
まず、GLiClass ライブラリをインストールする必要があります。
pip install gliclass
次に、モデルとパイプラインを初期化する必要があります。
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(result["label"], "=>", result["score"])
✨ 主な機能
- ゼロショット分類が可能で、事前学習なしで分類タスクを実行できます。
- トピック分類、感情分析、RAG パイプラインのリランカーとして利用できます。
- 合成データで学習されており、商用利用可能です。
📦 インストール
pip install gliclass
💻 使用例
基本的な使用法
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer
model = GLiClassModel.from_pretrained("knowledgator/gliclass-small-v1.0")
tokenizer = AutoTokenizer.from_pretrained("knowledgator/gliclass-small-v1.0")
pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device='cuda:0')
text = "One day I will see the world!"
labels = ["travel", "dreams", "sport", "science", "politics"]
results = pipeline(text, labels, threshold=0.5)[0]
for result in results:
print(result["label"], "=>", result["score"])
📚 ドキュメント
ベンチマーク:
以下に、いくつかのテキスト分類データセットでの F1 スコアを示します。すべてのテスト対象モデルは、これらのデータセットでファインチューニングされておらず、ゼロショット設定でテストされました。
以下に、他の GLiClass モデルとの比較を示します。
データセット |
gliclass-small-v1.0-lw |
gliclass-base-v1.0-lw |
gliclass-large-v1.0-lw |
gliclass-small-v1.0 |
gliclass-base-v1.0 |
gliclass-large-v1.0 |
CR |
0.8886 |
0.9097 |
0.9226 |
0.8824 |
0.8942 |
0.9219 |
sst2 |
0.8392 |
0.8987 |
0.9247 |
0.8518 |
0.8979 |
0.9269 |
sst5 |
0.2865 |
0.3779 |
0.2891 |
0.2424 |
0.2789 |
0.3900 |
20_news_groups |
0.4572 |
0.3953 |
0.4083 |
0.3366 |
0.3576 |
0.3863 |
spam |
0.5118 |
0.5126 |
0.3642 |
0.4089 |
0.4938 |
0.3661 |
rotten_tomatoes |
0.8015 |
0.8429 |
0.8807 |
0.7987 |
0.8508 |
0.8808 |
massive |
0.3180 |
0.4635 |
0.5606 |
0.2546 |
0.1893 |
0.4376 |
banking |
0.1768 |
0.4396 |
0.3317 |
0.1374 |
0.2077 |
0.2847 |
yahoo_topics |
0.4686 |
0.4784 |
0.4760 |
0.4477 |
0.4516 |
0.4921 |
financial_phrasebank |
0.8665 |
0.8880 |
0.9044 |
0.8901 |
0.8955 |
0.8735 |
imdb |
0.9048 |
0.9351 |
0.9429 |
0.8982 |
0.9238 |
0.9333 |
ag_news |
0.7252 |
0.6985 |
0.7559 |
0.7242 |
0.6848 |
0.7503 |
dair_emotion |
0.4012 |
0.3516 |
0.3951 |
0.3450 |
0.2357 |
0.4013 |
capsotu |
0.3794 |
0.4643 |
0.4749 |
0.3432 |
0.4375 |
0.4644 |
Average: |
0.5732 |
0.6183 |
0.6165 |
0.5401 |
0.5571 |
0.6078 |
以下に、より多くの例を提供することでモデルの性能がどのように向上するかを示します。
モデル |
例の数 |
sst5 |
spam |
massive |
banking |
ag news |
dair emotion |
capsotu |
平均 |
gliclass-small-v1.0-lw |
0 |
0.2865 |
0.5118 |
0.318 |
0.1768 |
0.7252 |
0.4012 |
0.3794 |
0.3998428571 |
gliclass-base-v1.0-lw |
0 |
0.3779 |
0.5126 |
0.4635 |
0.4396 |
0.6985 |
0.3516 |
0.4643 |
0.4725714286 |
gliclass-large-v1.0-lw |
0 |
0.2891 |
0.3642 |
0.5606 |
0.3317 |
0.7559 |
0.3951 |
0.4749 |
0.4530714286 |
gliclass-small-v1.0 |
0 |
0.2424 |
0.4089 |
0.2546 |
0.1374 |
0.7242 |
0.345 |
0.3432 |
0.3508142857 |
gliclass-base-v1.0 |
0 |
0.2789 |
0.4938 |
0.1893 |
0.2077 |
0.6848 |
0.2357 |
0.4375 |
0.3611 |
gliclass-large-v1.0 |
0 |
0.39 |
0.3661 |
0.4376 |
0.2847 |
0.7503 |
0.4013 |
0.4644 |
0.4420571429 |
gliclass-small-v1.0-lw |
8 |
0.2709 |
0.84026 |
0.62 |
0.6883 |
0.7786 |
0.449 |
0.4918 |
0.5912657143 |
gliclass-base-v1.0-lw |
8 |
0.4275 |
0.8836 |
0.729 |
0.7667 |
0.7968 |
0.3866 |
0.4858 |
0.6394285714 |
gliclass-large-v1.0-lw |
8 |
0.3345 |
0.8997 |
0.7658 |
0.848 |
0.84843 |
0.5219 |
0.508 |
0.67519 |
gliclass-small-v1.0 |
8 |
0.3042 |
0.5683 |
0.6332 |
0.7072 |
0.759 |
0.4509 |
0.4434 |
0.5523142857 |
gliclass-base-v1.0 |
8 |
0.3387 |
0.7361 |
0.7059 |
0.7456 |
0.7896 |
0.4323 |
0.4802 |
0.6040571429 |
gliclass-large-v1.0 |
8 |
0.4365 |
0.9018 |
0.77 |
0.8533 |
0.8509 |
0.5061 |
0.4935 |
0.6874428571 |
📄 ライセンス
このプロジェクトは Apache-2.0 ライセンスの下で提供されています。