🚀 gte-multilingual-base-xnli
本模型是 Alibaba-NLP/gte-multilingual-base 在XNLI数据集上的微调版本。它可用于多语言的零样本分类任务,为不同语言的文本分类提供了强大的支持。
✨ 主要特性
- 基于多语言基础模型微调,支持多种语言,包括英语、阿拉伯语、保加利亚语等15种语言。
- 可通过零样本分类管道或手动使用PyTorch进行NLI任务。
- 在多个数据集上进行了评估,展示了良好的性能。
📦 安装指南
文档中未提及具体安装步骤,若需使用该模型,可参考Hugging Face的相关文档进行安装。
💻 使用示例
基础用法
使用零样本分类管道加载模型:
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
使用该管道对序列进行分类:
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
若多个候选标签可能正确,可传递 multi_class=True
独立计算每个类别:
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
高级用法
手动使用PyTorch在NLI任务上应用模型:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
📚 详细文档
评估结果
模型在XNLI测试集的15种语言上进行了评估,使用的指标是准确率:
数据集 |
英语 (en) |
阿拉伯语 (ar) |
保加利亚语 (bg) |
德语 (de) |
希腊语 (el) |
西班牙语 (es) |
法语 (fr) |
印地语 (hi) |
俄语 (ru) |
斯瓦希里语 (sw) |
泰语 (th) |
土耳其语 (tr) |
乌尔都语 (ur) |
越南语 (vi) |
中文 (zh) |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
模型还在MultiNLI的开发集和ANLI的测试集上进行了评估,使用的指标同样是准确率:
训练超参数
训练过程中使用了以下超参数:
- 学习率 (learning_rate): 2e-05
- 训练批次大小 (train_batch_size): 16
- 评估批次大小 (eval_batch_size): 16
- 随机种子 (seed): 42
- 优化器 (optimizer): Adam,β=(0.9, 0.999),ε=1e-08
- 学习率调度器类型 (lr_scheduler_type): 线性
- 学习率调度器预热比例 (lr_scheduler_warmup_ratio): 0.1
框架版本
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
🔧 技术细节
本模型基于论文 mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval 进行开发,作者包括Xin Zhang、Yanzhao Zhang、Dingkun Long等,于2024年发表在arXiv上。
📄 许可证
本模型使用Apache-2.0许可证。