# Zero-shot Classification Model for Korean NLI
This model performs zero-shot classification in Korean. It fine-tunes the klue/roberta-base model on the mnli and xnli subsets of the kor_nli dataset, offering high accuracy in Korean text classification tasks.
## Quick Start

### Prerequisites

This model is based on the following tutorial repository: https://github.com/Huffon/klue-transformers-tutorial.git
## Model Details

| Property | Details |
| --- | --- |
| Model Type | Fine-tuned klue/roberta-base on kor_nli |
| Training Data | kor_nli (mnli, xnli) |
| License | Apache-2.0 |
| Metrics | Accuracy |
| Pipeline Tag | Zero-shot Classification |
### Training Parameters

| train_loss | val_loss | acc | epoch | batch | lr |
| --- | --- | --- | --- | --- | --- |
| 0.326 | 0.538 | 0.811 | 3 | 32 | 2e-5 |
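The hyperparameters above correspond to a standard Hugging Face fine-tuning configuration along these lines. This is a hypothetical sketch for illustration only: the actual training script is not published here, and the `output_dir` name is an assumption.

```python
from transformers import TrainingArguments

# Illustrative configuration matching the reported hyperparameters
# (epoch=3, batch=32, lr=2e-5); not the author's actual script.
training_args = TrainingArguments(
    output_dir="roberta_with_kornli",   # assumed name
    num_train_epochs=3,
    per_device_train_batch_size=32,
    learning_rate=2e-5,
    evaluation_strategy="epoch",
)
```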
## Code Modification for the Zero-shot Pipeline

Models that do not use `token_type_ids`, such as RoBERTa, cannot be used with the zero-shot pipeline directly (as of `transformers==4.7.0`), so the following conversion code is needed. It is adapted from the GitHub repository linked above.
```python
from abc import ABC, abstractmethod

from transformers import AutoTokenizer

# The handler below needs the model's tokenizer for its separator token.
tokenizer = AutoTokenizer.from_pretrained("pongjin/roberta_with_kornli")


class ArgumentHandler(ABC):
    """
    Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
    """

    @abstractmethod
    def __call__(self, *args, **kwargs):
        raise NotImplementedError()


class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
    """
    Handles arguments for zero-shot text classification by turning each possible
    label into an NLI premise/hypothesis pair.
    """

    def _parse_labels(self, labels):
        if isinstance(labels, str):
            labels = [label.strip() for label in labels.split(",")]
        return labels

    def __call__(self, sequences, labels, hypothesis_template):
        if len(labels) == 0 or len(sequences) == 0:
            raise ValueError("You must include at least one label and at least one sequence.")
        if hypothesis_template.format(labels[0]) == hypothesis_template:
            raise ValueError(
                (
                    'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                    "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                ).format(hypothesis_template)
            )

        if isinstance(sequences, str):
            sequences = [sequences]
        labels = self._parse_labels(labels)

        # Join premise and hypothesis explicitly with the separator token,
        # since RoBERTa does not use token_type_ids to mark the boundary.
        sequence_pairs = []
        for sequence in sequences:
            for label in labels:
                sequence_pairs.append(
                    f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}"
                )

        return sequence_pairs, sequences
```
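To see what the handler produces, here is a minimal, self-contained sketch of the same pairing logic. The hard-coded `"</s>"` and the `build_pairs` helper are illustrative stand-ins; in the real handler the separator comes from `tokenizer.sep_token`.

```python
def build_pairs(sequence, labels, hypothesis_template, sep_token="</s>"):
    """Turn one premise and N candidate labels into N NLI input strings."""
    return [
        f"{sequence} {sep_token} {hypothesis_template.format(label)}"
        for label in labels
    ]

pairs = build_pairs(
    "금리가 올랐다",
    ["경제", "스포츠"],
    "이는 {}에 관한 것이다.",
)
# pairs[0] == "금리가 올랐다 </s> 이는 경제에 관한 것이다."
```

Each candidate label yields one premise/hypothesis string, which the NLI model then scores for entailment.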
### Applying the Modified Code

Pass the custom argument handler when defining the classifier:

```python
from transformers import pipeline

classifier = pipeline(
    "zero-shot-classification",
    args_parser=CustomZeroShotClassificationArgumentHandler(),
    model="pongjin/roberta_with_kornli",
)
```
## 💻 Usage Examples

### Basic Usage

```python
sequence = "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자"
candidate_labels = ["외환", "환율", "경제", "금융", "부동산", "주식"]

classifier(
    sequence,
    candidate_labels,
    hypothesis_template="이는 {}에 관한 것이다.",
)
```

Output:

```python
{'sequence': '배당락 D-1 코스피, 2330선 상승세...외인·기관 사자',
 'labels': ['주식', '금융', '경제', '외환', '환율', '부동산'],
 'scores': [0.5052872896194458,
  0.17972524464130402,
  0.13852974772453308,
  0.09460823982954025,
  0.042949128895998,
  0.038900360465049744]}
```
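In the pipeline's default single-label mode, the returned scores are softmax probabilities over the candidate labels, so they sum to (approximately) 1 and the labels come back sorted by score. A small sketch reading off the top prediction from the output above:

```python
labels = ["주식", "금융", "경제", "외환", "환율", "부동산"]
scores = [0.5052872896194458, 0.17972524464130402, 0.13852974772453308,
          0.09460823982954025, 0.042949128895998, 0.038900360465049744]

# The pipeline already sorts labels by score, so the top prediction is first.
top_label = labels[scores.index(max(scores))]
# top_label == "주식"; sum(scores) is ~1.0
```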
## License

This model is licensed under the Apache-2.0 license.