🚀 mt5-large-finetuned-mnli-xtreme-xnli
本模型基於預訓練的大型 multilingual-t5(也可從 models 獲取),並在英文 MNLI 和 xtreme_xnli 訓練集上進行微調。它旨在用於零樣本文本分類,靈感來源於 xlm-roberta-large-xnli。
🚀 快速開始
本模型專為零樣本文本分類而設計,尤其適用於英文以外的語言。它在英文 MNLI 和 xtreme_xnli 訓練集(一個多語言自然語言推理數據集)上進行了微調。因此,該模型可用於 XNLI 語料庫中的任何語言:
- 阿拉伯語
- 保加利亞語
- 中文
- 英語
- 法語
- 德語
- 希臘語
- 印地語
- 俄語
- 西班牙語
- 斯瓦希里語
- 泰語
- 土耳其語
- 烏爾都語
- 越南語
根據 xlm-roberta-large-xnli 中的建議,若僅進行英文分類,你可以考慮以下模型:
✨ 主要特性
- 基於預訓練的多語言 T5 模型進行微調,適用於多語言零樣本文本分類。
- 微調後保留了文本到文本的特性,輸出為文本形式。
- 可用於 XNLI 語料庫中的多種語言。
💻 使用示例
基礎用法
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer
model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()
sequence_to_classify = "¿A quién vas a votar en 2020?"
candidate_labels = ["Europa", "salud pública", "política"]
hypothesis_template = "Este ejemplo es {}."
ENTAILS_LABEL = "▁0"
NEUTRAL_LABEL = "▁1"
CONTRADICTS_LABEL = "▁2"
label_inds = tokenizer.convert_tokens_to_ids(
[ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])
def process_nli(premise: str, hypothesis: str):
""" process to required xnli format with task prefix """
return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])
pairs = [(sequence_to_classify, hypothesis_template.format(label)) for label in
candidate_labels]
seqs = [process_nli(premise=premise, hypothesis=hypothesis) for
premise, hypothesis in pairs]
print(seqs)
inputs = tokenizer.batch_encode_plus(seqs, return_tensors="pt", padding=True)
out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
num_beams=1)
for i, seq in enumerate(out.sequences):
assert len(
seq) == 3, f"generated sequence {i} not of expected length, 3." \
f" Actual length: {len(seq)}"
scores = out.scores[0]
for i, sequence_scores in enumerate(scores):
top_scores = sequence_scores.argsort()[-3:]
assert set(top_scores.tolist()) == set(label_inds), \
f"top scoring tokens are not expected for this task." \
f" Expected: {label_inds}. Got: {top_scores.tolist()}."
scores = scores[:, label_inds]
print(scores)
entailment_ind = 0
contradiction_ind = 2
entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)
print(entail_vs_contra_probas)
entail_scores = scores[:, entailment_ind]
entail_probas = softmax(entail_scores, dim=0)
print(entail_probas)
print(dict(zip(candidate_labels, entail_probas.tolist())))
高級用法
🔧 技術細節
本模型在 mC4 中的 101 種語言上進行了預訓練,如 mt5 論文 所述。然後,它在 mt5_xnli_translate_train 任務上進行了 8000 步的微調,微調方式與 官方倉庫 中描述的類似,並參考了 Stephen Mayhew 的筆記本。最後,將得到的模型轉換為 :hugging_face: 格式。
📚 詳細文檔
評估結果
XNLI 測試集上的準確率:
語言代碼 |
準確率 |
ar |
81.0 |
bg |
85.0 |
de |
84.3 |
el |
84.3 |
en |
88.8 |
es |
85.3 |
fr |
83.9 |
hi |
79.9 |
ru |
82.6 |
sw |
78.0 |
th |
81.0 |
tr |
81.6 |
ur |
76.4 |
vi |
81.7 |
zh |
82.3 |
平均 |
82.4 |
📄 許可證
本模型採用 Apache-2.0 許可證。