🚀 gte-multilingual-base-xnli-anli
该模型是 Alibaba-NLP/gte-multilingual-base 在 XNLI 和 ANLI 数据集上的微调版本,可用于零样本分类任务,在多语言文本分类方面表现出色。
🚀 快速开始
本模型是 Alibaba-NLP/gte-multilingual-base 在 XNLI 和 ANLI 数据集上的微调版本。
✨ 主要特性
- 多语言支持:支持英语、阿拉伯语、保加利亚语等多种语言。
- 零样本分类:可进行零样本分类任务。
- 微调优化:基于基础模型在特定数据集上微调,提升性能。
📚 详细文档
模型描述
mGTE: Generalized Long-Context Text Representation and Reranking Models for Multilingual Text Retrieval
Xin Zhang, Yanzhao Zhang, Dingkun Long, Wen Xie, Ziqi Dai, Jialong Tang, Huan Lin, Baosong Yang, Pengjun Xie, Fei Huang, Meishan Zhang, Wenjie Li, Min Zhang, arXiv 2024
如何使用模型
💻 使用示例
基础用法
使用 zero-shot-classification
管道加载模型:
from transformers import AutoTokenizer, pipeline
model = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model)
classifier = pipeline("zero-shot-classification",
model=model,
tokenizer=tokenizer,
trust_remote_code=True
)
然后可以使用此管道将序列分类到指定的任何类名中:
sequence_to_classify = "one day I will see the world"
candidate_labels = ['travel', 'cooking', 'dancing']
classifier(sequence_to_classify, candidate_labels)
如果多个候选标签可能正确,可以传递 multi_class=True
来独立计算每个类:
candidate_labels = ['travel', 'cooking', 'dancing', 'exploration']
classifier(sequence_to_classify, candidate_labels, multi_class=True)
高级用法
手动使用 PyTorch 在 NLI 任务上应用模型:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_name = "mjwong/gte-multilingual-base-xnli-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, trust_remote_code=True)
premise = "But I thought you'd sworn off coffee."
hypothesis = "I thought that you vowed to drink more coffee."
input = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt")
output = model(input["input_ids"].to(device))
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 2) for pred, name in zip(prediction, label_names)}
print(prediction)
评估结果
该模型使用 XNLI 测试集在 15 种语言上进行了评估,使用的指标是准确率:
数据集 |
英语 (en) |
阿拉伯语 (ar) |
保加利亚语 (bg) |
德语 (de) |
希腊语 (el) |
西班牙语 (es) |
法语 (fr) |
印地语 (hi) |
俄语 (ru) |
斯瓦希里语 (sw) |
泰语 (th) |
土耳其语 (tr) |
乌尔都语 (ur) |
越南语 (vi) |
中文 (zh) |
gte-multilingual-base-xnli |
0.854 |
0.767 |
0.811 |
0.798 |
0.801 |
0.820 |
0.818 |
0.753 |
0.792 |
0.719 |
0.766 |
0.769 |
0.701 |
0.799 |
0.798 |
gte-multilingual-base-xnli-anli |
0.843 |
0.738 |
0.793 |
0.773 |
0.776 |
0.801 |
0.788 |
0.727 |
0.775 |
0.689 |
0.746 |
0.747 |
0.687 |
0.773 |
0.779 |
该模型还使用 MultiNLI 的开发集和 ANLI 的测试集进行了评估,使用的指标是准确率:
训练超参数
训练期间使用了以下超参数:
- 学习率:2e-05
- 训练批次大小:16
- 评估批次大小:16
- 随机种子:42
- 优化器:Adam,β=(0.9, 0.999),ε=1e-08
- 学习率调度器类型:线性
- 学习率调度器热身比例:0.1
框架版本
- Transformers 4.41.0
- Pytorch 2.6.0+cu124
- Datasets 3.2.0
- Tokenizers 0.19.1
📄 许可证
本模型采用 Apache-2.0 许可证。
📦 相关信息表格
属性 |
详情 |
模型类型 |
零样本分类模型 |
训练数据 |
XNLI、facebook/anli |
基础模型 |
Alibaba-NLP/gte-multilingual-base |