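🚀 Zero-Shot Classification Model: roberta_with_kornli
This model is based on klue/roberta-base and fine-tuned on the mnli and xnli subsets of the kor_nli dataset for zero-shot classification. Because the stock zero-shot pipeline cannot be applied directly to this kind of model, a custom argument handler class is used to enable zero-shot classification.
🚀 Quick Start
Model Information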
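| Property | Details |
|---|---|
| Model type | Zero-shot classification model fine-tuned from klue/roberta-base |
| Training data | mnli and xnli subsets of the kor_nli dataset |
| Evaluation metric | Accuracy |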
References
This model draws on the following GitHub repository: https://github.com/Huffon/klue-transformers-tutorial.git
Training Parameters
| train_loss | val_loss | acc | epoch | batch | lr |
|---|---|---|---|---|---|
| 0.326 | 0.538 | 0.811 | 3 | 32 | 2e-5 |
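Notes
Models such as RoBERTa that do not use token_type_ids cannot be used directly with the zero-shot pipeline in transformers==4.7.0. The conversion code below is therefore required; it is adapted from the code in the GitHub repository referenced above.
💻 Usage Examples
Basic Usage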
    from abc import ABC, abstractmethod

    from transformers import AutoTokenizer

    # Tokenizer of the fine-tuned model; the handler below reads its sep_token.
    tokenizer = AutoTokenizer.from_pretrained("pongjin/roberta_with_kornli")


    class ArgumentHandler(ABC):
        """
        Base interface for handling arguments for each :class:`~transformers.pipelines.Pipeline`.
        """

        @abstractmethod
        def __call__(self, *args, **kwargs):
            raise NotImplementedError()


    class CustomZeroShotClassificationArgumentHandler(ArgumentHandler):
        """
        Handles arguments for zero-shot for text classification by turning each possible label into an NLI
        premise/hypothesis pair.
        """

        def _parse_labels(self, labels):
            # Accept either a list of labels or a single comma-separated string.
            if isinstance(labels, str):
                labels = [label.strip() for label in labels.split(",")]
            return labels

        def __call__(self, sequences, labels, hypothesis_template):
            if len(labels) == 0 or len(sequences) == 0:
                raise ValueError("You must include at least one label and at least one sequence.")
            if hypothesis_template.format(labels[0]) == hypothesis_template:
                raise ValueError(
                    (
                        'The provided hypothesis_template "{}" was not able to be formatted with the target labels. '
                        "Make sure the passed template includes formatting syntax such as {{}} where the label should go."
                    ).format(hypothesis_template)
                )
            if isinstance(sequences, str):
                sequences = [sequences]
            labels = self._parse_labels(labels)

            sequence_pairs = []
            # Format each sequence on its own; interpolating the whole list
            # would embed its Python repr (e.g. "['text']") in the model input.
            for sequence in sequences:
                for label in labels:
                    # Join premise and hypothesis into one string so the
                    # tokenizer is never given a sentence pair.
                    sequence_pairs.append(
                        f"{sequence} {tokenizer.sep_token} {hypothesis_template.format(label)}"
                    )
            return sequence_pairs, sequences
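To make the transformation concrete, the following minimal sketch reproduces only the string-joining step, so it runs without loading any model. The `build_nli_inputs` helper and the hard-coded `"[SEP]"` separator are assumptions made for illustration; in the handler the separator comes from `tokenizer.sep_token`.

```python
def build_nli_inputs(sequence, labels, hypothesis_template, sep_token="[SEP]"):
    """Join one premise with one hypothesis per label into single strings.

    Hypothetical helper for illustration only; sep_token="[SEP]" is an
    assumed stand-in for tokenizer.sep_token.
    """
    # Accept a comma-separated string as well as a list of labels.
    if isinstance(labels, str):
        labels = [label.strip() for label in labels.split(",")]
    return [
        f"{sequence} {sep_token} {hypothesis_template.format(label)}"
        for label in labels
    ]


inputs = build_nli_inputs(
    "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자",
    "경제, 주식",              # "economy, stocks"
    "이는 {}에 관한 것이다.",    # "This is about {}."
)
for text in inputs:
    print(text)
```

Each candidate label yields one pre-joined string. Passing a single string rather than a sentence pair is what lets the pipeline sidestep the token_type_ids handling described in the notes.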
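Advanced Usage
Pass the custom handler above as args_parser when defining the classifier:

    from transformers import pipeline

    classifier = pipeline(
        "zero-shot-classification",
        args_parser=CustomZeroShotClassificationArgumentHandler(),
        model="pongjin/roberta_with_kornli",
    )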
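Classification Example

    # News headline about KOSPI rising near the 2330 level one day before the ex-dividend date.
    sequence = "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자"
    # Candidate labels: foreign exchange, exchange rate, economy, finance, real estate, stocks.
    candidate_labels = ["외환", "환율", "경제", "금융", "부동산", "주식"]
    classifier(
        sequence,
        candidate_labels,
        hypothesis_template="이는 {}에 관한 것이다.",  # "This is about {}."
    )

Output:

    {'sequence': '배당락 D-1 코스피, 2330선 상승세...외인·기관 사자',
     'labels': ['주식', '금융', '경제', '외환', '환율', '부동산'],
     'scores': [0.5052872896194458,
                0.17972524464130402,
                0.13852974772453308,
                0.09460823982954025,
                0.042949128895998,
                0.038900360465049744]}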
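The returned dictionary lists the labels in descending score order. As a minimal sketch of reading it, the snippet below copies the output shown above into a literal so that it runs without the model; the `result` variable name is an assumption standing in for whatever `classifier(...)` returned.

```python
# Output copied from the example; `result` is a hypothetical variable
# holding the dictionary returned by classifier(...).
result = {
    "sequence": "배당락 D-1 코스피, 2330선 상승세...외인·기관 사자",
    "labels": ["주식", "금융", "경제", "외환", "환율", "부동산"],
    "scores": [
        0.5052872896194458,
        0.17972524464130402,
        0.13852974772453308,
        0.09460823982954025,
        0.042949128895998,
        0.038900360465049744,
    ],
}

# Pair each label with its score; the first pair is the top prediction.
ranked = list(zip(result["labels"], result["scores"]))
top_label, top_score = ranked[0]
print(top_label, round(top_score, 3))

# Sum of the scores over all candidate labels.
total = sum(result["scores"])
print(total)
```

Here the top label is 주식 ("stocks"), and the scores add up to approximately 1, which is consistent with the pipeline normalizing across the candidate labels in its default single-label mode.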
📄 License
This project is licensed under Apache-2.0.