🚀 gte-multilingual-base-xnli-anli
該模型是 Alibaba-NLP/gte-multilingual-base 在 XNLI 和 ANLI 數據集上的微調版本,可用於零樣本分類任務,在多語言文本分類方面表現出色。
🚀 快速開始
本模型是 Alibaba-NLP/gte-multilingual-base 在 XNLI 和 ANLI 數據集上的微調版本。
✨ 主要特性
- 多語言支持:支持英語、阿拉伯語、保加利亞語等多種語言。
- 零樣本分類:可進行零樣本分類任務。
- 微調優化:基於基礎模型在特定數據集上微調,提升性能。
📚 詳細文檔
模型描述
mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval
Xin Zhang, Yanzhao Zhang, Dingkun Long, Wen Xie, Ziqi Dai, Jialong Tang, Huan Lin, Baosong Yang, Pengjun Xie, Fei Huang, Meishan Zhang, Wenjie Li, Min Zhang, arXiv 2024
如何使用模型
💻 使用示例
基礎用法
使用 zero-shot-classification
管道加載模型:
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
然後可以使用此管道將序列分類到指定的任何類名中:
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
如果多個候選標籤可能正確,可以傳遞 multi_class=True
來獨立計算每個類:
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
高級用法
手動使用 PyTorch 在 NLI 任務上應用模型:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
評估結果
該模型使用 XNLI 測試集在 15 種語言上進行了評估,使用的指標是準確率:
數據集 |
英語 (en) |
阿拉伯語 (ar) |
保加利亞語 (bg) |
德語 (de) |
希臘語 (el) |
西班牙語 (es) |
法語 (fr) |
印地語 (hi) |
俄語 (ru) |
斯瓦希里語 (sw) |
泰語 (th) |
土耳其語 (tr) |
烏爾都語 (ur) |
越南語 (vi) |
中文 (zh) |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
該模型還使用 MultiNLI 的開發集和 ANLI 的測試集進行了評估,使用的指標是準確率:
訓練超參數
訓練期間使用了以下超參數:
- 學習率:2e-05
- 訓練批次大小:16
- 評估批次大小:16
- 隨機種子:42
- 優化器:Adam,β=(0.9, 0.999),ε=1e-08
- 學習率調度器類型:線性
- 學習率調度器熱身比例:0.1
框架版本
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
📄 許可證
本模型採用 Apache-2.0 許可證。
📦 相關信息表格
屬性 |
詳情 |
模型類型 |
零樣本分類模型 |
訓練數據 |
XNLI、facebook/anli |
基礎模型 |
Alibaba-NLP/gte-multilingual-base |