🚀 mt5-large-finetuned-mnli-xtreme-xnli
This model takes a pre-trained large multilingual-t5 (also available from the Hugging Face models hub) and fine-tunes it on English MNLI and the xtreme_xnli training set. It is intended for zero-shot text classification.
🚀 Quick Start
This model is suitable for zero-shot text classification, especially in languages other than English. It is fine-tuned on English MNLI and on the xtreme_xnli training set, a multilingual NLI dataset, so it can be used with any language in the XNLI corpus, including Arabic, Bulgarian, Chinese, and the others listed under Supported Languages.
✨ Features
- Multilingual Support: can be used with any of the languages in the XNLI corpus.
- Zero-shot Classification: enables zero-shot text classification, following the approach of [xlm-roberta-large-xnli](https://huggingface.co/joeddav/xlm-roberta-large-xnli).
💻 Usage Examples
Basic Usage
```python
from torch.nn.functional import softmax
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

model_name = "alan-turing-institute/mt5-large-finetuned-mnli-xtreme-xnli"
tokenizer = MT5Tokenizer.from_pretrained(model_name)
model = MT5ForConditionalGeneration.from_pretrained(model_name)
model.eval()

sequence_to_classify = "¿A quién vas a votar en 2020?"
candidate_labels = ["Europa", "salud pública", "política"]
hypothesis_template = "Este ejemplo es {}."

# the model emits the NLI class as a single SentencePiece token
ENTAILS_LABEL = "▁0"
NEUTRAL_LABEL = "▁1"
CONTRADICTS_LABEL = "▁2"

label_inds = tokenizer.convert_tokens_to_ids(
    [ENTAILS_LABEL, NEUTRAL_LABEL, CONTRADICTS_LABEL])


def process_nli(premise: str, hypothesis: str):
    """ process to required xnli format with task prefix """
    return "".join(['xnli: premise: ', premise, ' hypothesis: ', hypothesis])


# construct sequence of premise, hypothesis pairs
pairs = [(sequence_to_classify, hypothesis_template.format(label))
         for label in candidate_labels]
# format for mt5 xnli task
seqs = [process_nli(premise=premise, hypothesis=hypothesis)
        for premise, hypothesis in pairs]
print(seqs)
# ['xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es Europa.',
# 'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es salud pública.',
# 'xnli: premise: ¿A quién vas a votar en 2020? hypothesis: Este ejemplo es política.']

inputs = tokenizer(seqs, return_tensors="pt", padding=True)
out = model.generate(**inputs, output_scores=True, return_dict_in_generate=True,
                     num_beams=1)

# sanity check that our sequences are the expected length
# (decoder start token + label token + end token = 3)
for i, seq in enumerate(out.sequences):
    assert len(seq) == 3, f"generated sequence {i} not of expected length, 3." \
                          f" Actual length: {len(seq)}"

# get the scores for our only token of interest
# we'll now treat these like the output logits of a `*ForSequenceClassification` model
scores = out.scores[0]

# scores has a size of the model's vocab.
# However, for this task we have a fixed set of labels
# sanity check that these labels are always the top 3 scoring
for i, sequence_scores in enumerate(scores):
    top_scores = sequence_scores.argsort()[-3:]
    assert set(top_scores.tolist()) == set(label_inds), \
        f"top scoring tokens are not expected for this task." \
        f" Expected: {label_inds}. Got: {top_scores.tolist()}."

# cut down scores to our task labels
scores = scores[:, label_inds]
print(scores)
# tensor([[-2.5697, 1.0618, 0.2088],
# [-5.4492, -2.1805, -0.1473],
# [ 2.2973, 3.7595, -0.1769]])

# new indices of entailment and contradiction in scores
entailment_ind = 0
contradiction_ind = 2

# we can show, per item, the entailment vs contradiction probas
entail_vs_contra_scores = scores[:, [entailment_ind, contradiction_ind]]
entail_vs_contra_probas = softmax(entail_vs_contra_scores, dim=1)
print(entail_vs_contra_probas)
# tensor([[0.0585, 0.9415],
# [0.0050, 0.9950],
# [0.9223, 0.0777]])

# or we can show probas similar to `ZeroShotClassificationPipeline`
# this gives a zero-shot classification style output across labels
entail_scores = scores[:, entailment_ind]
entail_probas = softmax(entail_scores, dim=0)
print(entail_probas)
# tensor([7.6341e-03, 4.2873e-04, 9.9194e-01])
print(dict(zip(candidate_labels, entail_probas.tolist())))
# {'Europa': 0.007634134963154793,
# 'salud pública': 0.0004287279152777046,
# 'política': 0.9919371604919434}
```
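The post-processing at the end of the example is ordinary softmax arithmetic over the three label logits. As a sketch that needs neither PyTorch nor a model download, the two scoring modes can be reproduced in plain Python from the logits printed above. The function names are illustrative only and are not part of any library.

```python
import math
from typing import Dict, List, Sequence

# Column order of each logit row, matching `label_inds` in the example
# above: 0 = entailment, 1 = neutral, 2 = contradiction.
ENTAILMENT, CONTRADICTION = 0, 2


def softmax_list(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a plain list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def entail_vs_contra(rows: Sequence[Sequence[float]]) -> List[List[float]]:
    """Per label: [P(entailment), P(contradiction)], ignoring the neutral logit."""
    return [softmax_list([r[ENTAILMENT], r[CONTRADICTION]]) for r in rows]


def label_distribution(labels: Sequence[str],
                       rows: Sequence[Sequence[float]]) -> Dict[str, float]:
    """Softmax of the entailment logits across labels (single-label zero-shot)."""
    return dict(zip(labels, softmax_list([r[ENTAILMENT] for r in rows])))


# Logits printed by the example above: one row per candidate label.
rows = [[-2.5697, 1.0618, 0.2088],
        [-5.4492, -2.1805, -0.1473],
        [2.2973, 3.7595, -0.1769]]
labels = ["Europa", "salud pública", "política"]

print(label_distribution(labels, rows))
print(entail_vs_contra(rows))
```

With these logits the results agree with the tensors printed above up to rounding, for example roughly 0.992 for `política`. The first mode treats labels as mutually exclusive; the second scores each label independently, which suits multi-label use.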
Advanced Usage
The code above shows a basic way to use the model. Note that the `generate` function of the TensorFlow equivalent model does not exactly mirror the PyTorch version, so this code will not transfer directly. The model is also not currently compatible with the existing `zero-shot-classification` pipeline.
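Because the model cannot be dropped into the `zero-shot-classification` pipeline, code that expects that pipeline's output may want the probabilities arranged the same way: a dict with `sequence`, `labels`, and `scores`, with labels sorted from most to least likely. The helper below is a hypothetical sketch of that reshaping, not a transformers API; it takes probabilities you have already computed, such as `entail_probas.tolist()` from the basic example.

```python
from typing import Dict, List, Sequence, Union


def to_pipeline_output(sequence: str, labels: Sequence[str],
                       probs: Sequence[float]) -> Dict[str, Union[str, List]]:
    """Hypothetical helper: mimic the dict layout of ZeroShotClassificationPipeline."""
    if len(labels) != len(probs):
        raise ValueError("need exactly one probability per label")
    # Sort (label, probability) pairs from most to least likely.
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return {
        "sequence": sequence,
        "labels": [label for label, _ in ranked],
        "scores": [float(p) for _, p in ranked],
    }


result = to_pipeline_output(
    "¿A quién vas a votar en 2020?",
    ["Europa", "salud pública", "política"],
    [0.0076, 0.0004, 0.9919],
)
print(result["labels"])  # ['política', 'Europa', 'salud pública']
```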
🔧 Technical Details
This model was pre-trained on the 101 languages of mC4, as described in the mT5 paper. It was then fine-tuned on the [mt5_xnli_translate_train](https://github.com/google-research/multilingual-t5/blob/78d102c830d76bd68f27596a97617e2db2bfc887/multilingual_t5/tasks.py#L190) task for 8k steps, following the method in the [official repo](https://github.com/google-research/multilingual-t5#fine-tuning) and guided by [Stephen Mayhew's notebook](https://github.com/mayhewsw/multilingual-t5/blob/master/notebooks/mt5-xnli.ipynb). Finally, the resulting model was converted to 🤗 (Hugging Face) format.
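mT5 treats XNLI as a text-to-text task: the premise and hypothesis are serialized into one input string with the `xnli:` task prefix (the same format `process_nli` builds in the usage example), and the model generates the class as a single token. The usage example reads the tokens `▁0`, `▁1`, and `▁2` as entailment, neutral, and contradiction, so this sketch assumes the training targets were the strings "0", "1", and "2". That mapping is inferred from the inference code, not taken from the training code, and `to_text_to_text` is a hypothetical helper.

```python
from typing import Tuple

# Assumed label-to-target mapping, inferred from the ▁0/▁1/▁2 tokens read at inference.
LABEL_TO_TARGET = {"entailment": "0", "neutral": "1", "contradiction": "2"}


def to_text_to_text(premise: str, hypothesis: str, label: str) -> Tuple[str, str]:
    """Hypothetical helper: serialize one XNLI example as an (input, target) string pair."""
    source = f"xnli: premise: {premise} hypothesis: {hypothesis}"
    return source, LABEL_TO_TARGET[label]


src, tgt = to_text_to_text("The cat sat on the mat.", "An animal is sitting.", "entailment")
print(src)  # xnli: premise: The cat sat on the mat. hypothesis: An animal is sitting.
print(tgt)  # 0
```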
📚 Documentation
Intended Use
This model is intended for zero-shot text classification and is especially useful for non-English languages. If you only need English classification, consider [bart-large-mnli](https://huggingface.co/facebook/bart-large-mnli) or [a distilled bart MNLI model](https://huggingface.co/models?filter=pipeline_tag%3Azero-shot-classification&search=valhalla).
Eval results
Accuracy (%) of the model on the XNLI test set, per language:
ar | bg | de | el | en | es | fr | hi | ru | sw | th | tr | ur | vi | zh | average |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
81.0 | 85.0 | 84.3 | 84.3 | 88.8 | 85.3 | 83.9 | 79.9 | 82.6 | 78.0 | 81.0 | 81.6 | 76.4 | 81.7 | 82.3 | 82.4 |
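The `average` column is consistent with an unweighted mean of the 15 per-language scores. Every language in the XNLI test set contains the same translated examples, so this macro average also equals pooled accuracy. A quick check, using only the numbers in the table:

```python
# Per-language XNLI test accuracy (%) copied from the table above.
accuracy = {
    "ar": 81.0, "bg": 85.0, "de": 84.3, "el": 84.3, "en": 88.8,
    "es": 85.3, "fr": 83.9, "hi": 79.9, "ru": 82.6, "sw": 78.0,
    "th": 81.0, "tr": 81.6, "ur": 76.4, "vi": 81.7, "zh": 82.3,
}

# Unweighted (macro) mean over the 15 languages.
average = sum(accuracy.values()) / len(accuracy)
print(f"{average:.1f}")  # 82.4
```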
📄 License
This model is licensed under the Apache 2.0 license.
Additional Information
Property | Details |
---|---|
Model Type | mt5-large-finetuned-mnli-xtreme-xnli |
Training Data | multi_nli, xtreme_xnli |
Metrics | xnli |
Supported Languages | Arabic, Bulgarian, Chinese, English, French, German, Greek, Hindi, Russian, Spanish, Swahili, Thai, Turkish, Urdu, Vietnamese |
Tags | pytorch |

