🚀 gte-multilingual-base-xnli
このモデルは、XNLIデータセットでAlibaba-NLP/gte-multilingual-baseをファインチューニングしたバージョンです。XNLIデータセットを用いて、多言語に対応したゼロショット分類タスクに特化した性能を持ちます。
✨ 主な機能
モデルの説明
mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval。
Xin Zhang, Yanzhao Zhang, Dingkun Long, Wen Xie, Ziqi Dai, Jialong Tang, Huan Lin, Baosong Yang, Pengjun Xie, Fei Huang, Meishan Zhang, Wenjie Li, Min Zhang, arXiv 2024
📦 インストール
このモデルを使用するには、必要なライブラリをインストールする必要があります。以下のコマンドを使用して、transformers
などのライブラリをインストールできます。
pip install transformers datasets torch tokenizers
💻 使用例
基本的な使用法
ゼロショット分類パイプラインを使用する場合
このモデルは、zero-shot-classification
パイプラインを使って以下のようにロードできます。
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
次に、このパイプラインを使って、指定したクラス名にシーケンスを分類できます。
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
複数の候補ラベルが正しい場合、multi_class=True
を渡して各クラスを独立して計算できます。
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
手動でPyTorchを使用する場合
このモデルは、NLIタスクにも適用できます。
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
高度な使用法
このモデルは、XNLIデータセットを用いて15の言語で評価されています。また、MultiNLIの開発セットとANLIのテストセットを用いても評価されています。評価指標は精度です。
XNLIテストセットによる評価結果
Datasets |
en |
ar |
bg |
de |
el |
es |
fr |
hi |
ru |
sw |
th |
tr |
ur |
vi |
zh |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
MultiNLI開発セットとANLIテストセットによる評価結果
🔧 技術詳細
学習ハイパーパラメータ
学習時には、以下のハイパーパラメータが使用されました。
- learning_rate: 2e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
フレームワークバージョン
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
📄 ライセンス
このモデルは、Apache-2.0ライセンスの下で提供されています。