🚀 gte-multilingual-base-xnli
本模型是 Alibaba-NLP/gte-multilingual-base 在XNLI數據集上的微調版本。它可用於多語言的零樣本分類任務,為不同語言的文本分類提供了強大的支持。
✨ 主要特性
- 基於多語言基礎模型微調,支持多種語言,包括英語、阿拉伯語、保加利亞語等15種語言。
- 可通過零樣本分類管道或手動使用PyTorch進行NLI任務。
- 在多個數據集上進行了評估,展示了良好的性能。
📦 安裝指南
文檔中未提及具體安裝步驟,若需使用該模型,可參考Hugging Face的相關文檔進行安裝。
💻 使用示例
基礎用法
使用零樣本分類管道加載模型:
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
使用該管道對序列進行分類:
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
若多個候選標籤可能正確,可傳遞 multi_class=True
獨立計算每個類別:
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
高級用法
手動使用PyTorch在NLI任務上應用模型:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
📚 詳細文檔
評估結果
模型在XNLI測試集的15種語言上進行了評估,使用的指標是準確率:
數據集 |
英語 (en) |
阿拉伯語 (ar) |
保加利亞語 (bg) |
德語 (de) |
希臘語 (el) |
西班牙語 (es) |
法語 (fr) |
印地語 (hi) |
俄語 (ru) |
斯瓦希里語 (sw) |
泰語 (th) |
土耳其語 (tr) |
烏爾都語 (ur) |
越南語 (vi) |
中文 (zh) |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
模型還在MultiNLI的開發集和ANLI的測試集上進行了評估,使用的指標同樣是準確率:
訓練超參數
訓練過程中使用了以下超參數:
- 學習率 (learning_rate): 2e-05
- 訓練批次大小 (train_batch_size): 16
- 評估批次大小 (eval_batch_size): 16
- 隨機種子 (seed): 42
- 優化器 (optimizer): Adam,β=(0.9, 0.999),ε=1e-08
- 學習率調度器類型 (lr_scheduler_type): 線性
- 學習率調度器預熱比例 (lr_scheduler_warmup_ratio): 0.1
框架版本
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
🔧 技術細節
本模型基於論文 mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval 進行開發,作者包括Xin Zhang、Yanzhao Zhang、Dingkun Long等,於2024年發表在arXiv上。
📄 許可證
本模型使用Apache-2.0許可證。