🚀 DeBERTa-v3-base-mnli-fever-anli
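This model is intended for text classification and zero-shot classification. It is fine-tuned on three natural language inference (NLI) datasets, MultiNLI, Fever-NLI, and Adversarial-NLI (ANLI), and can be used for NLI tasks in research and applications.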
🚀 Quick Start
Simple zero-shot classification pipeline
from transformers import pipeline
classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
sequence_to_classify = "Angela Merkel is a politician in Germany and leader of the CDU"
candidate_labels = ["politics", "economy", "entertainment", "environment"]
output = classifier(sequence_to_classify, candidate_labels, multi_label=False)
print(output)
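The pipeline call above returns a dictionary with the keys "sequence", "labels", and "scores", where each score belongs to the label at the same position. The sketch below shows one way to read the top label from such a result. Note that top_label is a hypothetical helper and example_output is made-up illustration data, not real model output.

```python
def top_label(output, threshold=0.0):
    """Return (label, score) for the best-scoring label, or None if it is below threshold."""
    # Pair each label with its score and pick the pair with the highest score.
    label, score = max(zip(output["labels"], output["scores"]), key=lambda pair: pair[1])
    return (label, score) if score >= threshold else None


# Made-up dictionary with the same keys as the pipeline output (not real model scores).
example_output = {
    "sequence": "Angela Merkel is a politician in Germany and leader of the CDU",
    "labels": ["politics", "economy", "environment", "entertainment"],
    "scores": [0.9, 0.05, 0.03, 0.02],
}

print(top_label(example_output))                  # best label with its score
print(top_label(example_output, threshold=0.95))  # None: no label is confident enough
```

A threshold like this can be useful when none of the candidate labels fits the text well; the value 0.95 here is arbitrary.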
NLI use case
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Move the model to the same device as the inputs; otherwise the forward pass fails on GPU.
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)

premise = "I first thought that I liked the movie, but upon second thought it was actually disappointing."
hypothesis = "The movie was good."

# Tokenize the premise-hypothesis pair and move every input tensor (including the attention mask) to the device.
inputs = tokenizer(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
with torch.no_grad():  # inference only, no gradients needed
    output = model(**inputs)

# Convert the logits into a percentage per class.
prediction = torch.softmax(output["logits"][0], -1).tolist()
label_names = ["entailment", "neutral", "contradiction"]
prediction = {name: round(float(pred) * 100, 1) for pred, name in zip(prediction, label_names)}
print(prediction)
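The last three steps of the NLI example (softmax over the logits, mapping to label names, rounding to percentages) can be illustrated without loading the model. The following is a minimal sketch in plain Python. The function names and the logits are made up for illustration, and the label order is taken from the example above.

```python
import math

# Label order as used in the NLI example above.
LABEL_NAMES = ["entailment", "neutral", "contradiction"]


def logits_to_percentages(logits, label_names=LABEL_NAMES):
    """Apply a numerically stable softmax and return a percentage per label."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return {name: round(e / total * 100, 1) for name, e in zip(label_names, exps)}


def predicted_label(percentages):
    """Return the label with the highest percentage."""
    return max(percentages, key=percentages.get)


# Made-up logits, not produced by the model.
example = logits_to_percentages([-2.0, 0.5, 3.0])
print(example)
print(predicted_label(example))
```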
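✨ Key Features
- The model is trained on the MultiNLI, Fever-NLI, and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 NLI hypothesis-premise pairs.
- This base model outperforms almost all large models on the ANLI benchmark.
- The underlying model is Microsoft's DeBERTa-v3-base. The v3 variant of DeBERTa substantially outperforms earlier versions of the model thanks to a different pre-training objective.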
📦 Installation
Install the transformers library before using the model, for example with pip:
pip install transformers
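📚 Documentation
Training data
DeBERTa-v3-base-mnli-fever-anli was trained on the MultiNLI, Fever-NLI, and Adversarial-NLI (ANLI) datasets, which together comprise 763,913 NLI hypothesis-premise pairs.
Training procedure
DeBERTa-v3-base-mnli-fever-anli was trained with the Hugging Face Trainer using the following hyperparameters: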
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",          # hypothetical output path, not given in the original
    num_train_epochs=3,              # total number of training epochs
    learning_rate=2e-05,
    per_device_train_batch_size=32,  # batch size per device during training
    per_device_eval_batch_size=32,   # batch size for evaluation
    warmup_ratio=0.1,                # fraction of total training steps used for learning rate warmup
    weight_decay=0.06,               # strength of weight decay
    fp16=True                        # mixed precision training
)
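To make warmup_ratio concrete, the sketch below estimates how many optimizer steps the warmup would cover. It rests on assumptions not stated in the original: a single device, no gradient accumulation, all 763,913 pairs used for training, and the last partial batch kept. The helper warmup_steps is hypothetical and only approximates how a scheduler derives the count.

```python
import math


def warmup_steps(num_examples, batch_size, epochs, warmup_ratio):
    """Estimate (total optimizer steps, warmup steps) for one device without gradient accumulation."""
    steps_per_epoch = math.ceil(num_examples / batch_size)  # keep the last partial batch
    total_steps = steps_per_epoch * epochs
    return total_steps, math.ceil(total_steps * warmup_ratio)


# Values from the hyperparameters above; the dataset size is the pair count reported for the training data.
total, warmup = warmup_steps(num_examples=763913, batch_size=32, epochs=3, warmup_ratio=0.1)
print(total, warmup)  # 71619 7162
```

With more devices or gradient accumulation, the effective batch size grows and both numbers shrink accordingly.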
Evaluation results
The model was evaluated on the test sets of MultiNLI and ANLI and on the dev set of Fever-NLI. The metric is accuracy.
| mnli-m | mnli-mm | fever-nli | anli-all | anli-r3 |
|--------|---------|-----------|----------|---------|
| 0.903 | 0.903 | 0.777 | 0.579 | 0.495 |
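Accuracy here means the share of examples whose predicted class matches the gold label. The sketch below shows the computation; the accuracy helper and the sample labels are made up for illustration and are not the evaluation script used for the numbers above.

```python
def accuracy(predictions, references):
    """Return the fraction of predictions that equal the corresponding reference label."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(references)


# Made-up predictions and gold labels: three of the four match.
preds = ["entailment", "neutral", "contradiction", "neutral"]
golds = ["entailment", "neutral", "entailment", "neutral"]
print(accuracy(preds, golds))  # 0.75
```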
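🔧 Technical Details
The model is based on Microsoft's DeBERTa-v3-base. The v3 variant substantially outperforms earlier versions of the model thanks to a different pre-training objective; see appendix 11 of the original DeBERTa paper for details.
📄 License
This project is released under the MIT license.
⚠️ Important Note
Please consult the original DeBERTa paper and the literature on the individual NLI datasets for information on potential biases.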