🚀 gte-multilingual-base-xnli-anli
このモデルは、XNLIとANLIデータセットでAlibaba-NLP/gte-multilingual-baseをファインチューニングしたバージョンです。
🚀 クイックスタート
モデルの説明
mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval
Xin Zhang, Yanzhao Zhang, Dingkun Long, Wen Xie, Ziqi Dai, Jialong Tang, Huan Lin, Baosong Yang, Pengjun Xie, Fei Huang, Meishan Zhang, Wenjie Li, Min Zhang, arXiv 2024
モデルの使用方法
ゼロショット分類パイプラインを使用する場合
モデルは、以下のようにzero-shot-classification
パイプラインでロードできます。
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
このパイプラインを使用して、指定したクラス名のいずれかにシーケンスを分類できます。
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
複数の候補ラベルが正しい場合、multi_class=True
を渡して各クラスを独立して計算できます。
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
手動でPyTorchを使用する場合
モデルは、以下のようにNLIタスクにも適用できます。
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
評価結果
モデルは、XNLIのテストセットを使用して、15の言語(英語 (en)、アラビア語 (ar)、ブルガリア語 (bg)、ドイツ語 (de)、ギリシャ語 (el)、スペイン語 (es)、フランス語 (fr)、ヒンディー語 (hi)、ロシア語 (ru)、スワヒリ語 (sw)、タイ語 (th)、トルコ語 (tr)、ウルドゥー語 (ur)、ベトナム語 (vi)、中国語 (zh))で評価されました。使用された指標は正解率です。
データセット |
en |
ar |
bg |
de |
el |
es |
fr |
hi |
ru |
sw |
th |
tr |
ur |
vi |
zh |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
モデルは、MultiNLIの開発セットとANLIのテストセットを使用しても評価されました。使用された指標は正解率です。
学習ハイパーパラメータ
学習中に以下のハイパーパラメータが使用されました。
- learning_rate: 2e-05
- train_batch_size: 16
- eval_batch_size: 16
- seed: 42
- optimizer: Adam with betas=(0.9,0.999) and epsilon=1e-08
- lr_scheduler_type: linear
- lr_scheduler_warmup_ratio: 0.1
フレームワークのバージョン
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
📄 ライセンス
このモデルは、Apache-2.0ライセンスの下で提供されています。
📚 ドキュメント
モデルの属性情報
属性 |
詳細 |
モデルタイプ |
ゼロショット分類 |
ベースモデル |
Alibaba-NLP/gte-multilingual-base |
学習データ |
xnli、facebook/anli |