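🚀 Fine-tuned DistilBERT for multi-label text classification on Reuters-21578
This project fine-tunes the distilbert-base-cased model on the Reuters-21578 (reuters21578) multi-label dataset to produce a multi-label classifier for news articles. The model outperforms a Scikit-learn baseline on every F1 average reported below (for example, a micro-averaged F1 of 0.86 versus 0.77).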
🚀 Quick Start
Inference example
from transformers import pipeline
pipe = pipeline("text-classification", model="lxyuan/distilbert-finetuned-reuters21578-multilabel", return_all_scores=True)
# dataset["test"]["text"][2]
news_article = (
"JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS The Ministry of International Trade and "
"Industry (MITI) will revise its long-term energy supply/demand "
"outlook by August to meet a forecast downtrend in Japanese "
"energy demand, ministry officials said. "
"MITI is expected to lower the projection for primary energy "
"supplies in the year 2000 to 550 mln kilolitres (kl) from 600 "
"mln, they said. "
"The decision follows the emergence of structural changes in "
"Japanese industry following the rise in the value of the yen "
"and a decline in domestic electric power demand. "
"MITI is planning to work out a revised energy supply/demand "
"outlook through deliberations of committee meetings of the "
"Agency of Natural Resources and Energy, the officials said. "
"They said MITI will also review the breakdown of energy "
"supply sources, including oil, nuclear, coal and natural gas. "
"Nuclear energy provided the bulk of Japan's electric power "
"in the fiscal year ended March 31, supplying an estimated 27 "
"pct on a kilowatt/hour basis, followed by oil (23 pct) and "
"liquefied natural gas (21 pct), they noted. "
"REUTER"
)
# dataset["test"]["topics"][2]
target_topics = ['crude', 'nat-gas']
# Pad or truncate the input to the 512-token limit
fn_kwargs = {"padding": "max_length", "truncation": True, "max_length": 512}
output = pipe(news_article, function_to_apply="sigmoid", **fn_kwargs)
# Print every label whose sigmoid score reaches the 0.5 threshold
for item in output[0]:
    if item["score"] >= 0.5:
        print(item["label"], item["score"])
>>> crude 0.7355073690414429
nat-gas 0.8600426316261292
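The call above passes function_to_apply="sigmoid", so each label receives an independent probability, and the loop keeps labels scoring at least 0.5. A minimal pure-Python sketch of that post-processing follows. The logits and the helper name `select_labels` are illustrative assumptions, not real model output or part of the Transformers API.

```python
import math


def sigmoid(x):
    """Map a raw logit to an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def select_labels(logits, threshold=0.5):
    """Keep every label whose sigmoid score reaches the threshold."""
    scores = {label: sigmoid(z) for label, z in logits.items()}
    return {label: s for label, s in scores.items() if s >= threshold}


# Hypothetical logits for three labels (not real model output)
example_logits = {"crude": 1.0, "nat-gas": 2.0, "earn": -3.0}
selected = select_labels(example_logits)
print(sorted(selected))  # ['crude', 'nat-gas']
```

Unlike softmax scores, sigmoid scores do not sum to 1, which is why several labels can pass the threshold for the same article.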
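✨ Key Features
- Multi-label classification: assigns several topic labels to a single news article.
- Strong performance: beats the Scikit-learn baseline on weighted-averaged F1 (0.84 vs. 0.70) and samples-averaged F1 (0.80 vs. 0.75).
- Class imbalance: does somewhat better on minority classes (macro-averaged F1 of 0.33 vs. 0.29), although there is still room for improvement.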
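📚 Detailed Documentation
Model origin
This model was forked from distilbert-finetuned-reuters21578-multilabel solely to generate an ONNX version, which is stored in the /onnx directory.
Motivation
Fine-tuning on the Reuters-21578 multi-label dataset is a worthwhile exercise because the dataset is often used for take-home interview assignments. Its complexity is well suited to testing multi-label classification skills within a limited time, and it is close enough to real applications to reflect practical challenges. Experimenting with it helps candidates prepare for interviews and builds skills in data preprocessing, feature extraction, and model evaluation.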
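Overall summary and comparison table
Metric | Scikit-learn baseline | Transformer model |
---|---|---|
Micro-averaged F1 | 0.77 | 0.86 |
Macro-averaged F1 | 0.29 | 0.33 |
Weighted-averaged F1 | 0.70 | 0.84 |
Samples-averaged F1 | 0.75 | 0.80 |
Precision vs. recall: Both models favor precision over recall. For a customer-facing news classifier, precision matters more than recall, because a false positive is more damaging than a false negative and harder to explain to customers.
Class imbalance: Both models handle minority classes poorly, as the low macro-averaged F1 scores show. The Transformer model improves on the baseline slightly (0.33 vs. 0.29).
Zero-support labels: In both reports, some labels have no samples in the test set (support 0), which heavily affects the performance metrics. This may reflect weak prediction on minority classes, or simply a lack of samples for these classes in the dataset. The two reports below also score these labels differently (0.00 for the Transformer model, 1.00 for the baseline), so the macro averages should be compared with care.
Overall performance: The Transformer model beats the Scikit-learn baseline on weighted-averaged F1 and samples-averaged F1, which indicates better overall performance and better handling of label imbalance.
Conclusion: Both models achieve high precision, but the Transformer model is better on all four F1 averages in the table, strikes a better balance between precision and recall, and shows some improvement on minority classes. It shares the baseline's weaknesses, but its incremental gains could be significant in a production setting.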
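To make the difference between the averages concrete, here is a small pure-Python sketch of micro- and macro-averaged F1 on toy data. The samples and helper names are made up for illustration, and the `zero_division` fallback is a simplified stand-in for how a label with no true and no predicted samples can be scored as either 0 or 1; it is an assumption, not the evaluation code used for this model.

```python
def per_label_counts(y_true, y_pred, labels):
    """Return {label: (tp, fp, fn)} for parallel lists of label sets."""
    counts = {}
    for label in labels:
        tp = sum(1 for t, p in zip(y_true, y_pred) if label in t and label in p)
        fp = sum(1 for t, p in zip(y_true, y_pred) if label not in t and label in p)
        fn = sum(1 for t, p in zip(y_true, y_pred) if label in t and label not in p)
        counts[label] = (tp, fp, fn)
    return counts


def f1(tp, fp, fn, zero_division=0.0):
    """F1 = 2TP / (2TP + FP + FN); use zero_division when it is undefined."""
    denom = 2 * tp + fp + fn
    return zero_division if denom == 0 else 2 * tp / denom


def micro_f1(counts):
    """Pool counts over all labels, then compute one F1."""
    tp = sum(c[0] for c in counts.values())
    fp = sum(c[1] for c in counts.values())
    fn = sum(c[2] for c in counts.values())
    return f1(tp, fp, fn)


def macro_f1(counts, zero_division=0.0):
    """Compute F1 per label, then take the unweighted mean."""
    scores = [f1(*c, zero_division=zero_division) for c in counts.values()]
    return sum(scores) / len(scores)


# Toy data: "earn" is frequent, "acq" is always missed, "rye" never occurs
labels = ["earn", "acq", "rye"]
y_true = [{"earn"}, {"earn"}, {"acq"}, {"earn", "acq"}]
y_pred = [{"earn"}, {"earn"}, set(), {"earn"}]
counts = per_label_counts(y_true, y_pred, labels)

print(round(micro_f1(counts), 2))                     # 0.75
print(round(macro_f1(counts), 2))                     # 0.33
print(round(macro_f1(counts, zero_division=1.0), 2))  # 0.67
```

The frequent label dominates the micro average, while the macro average is pulled down by the missed label and swings from 0.33 to 0.67 depending only on how the empty label is scored.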
Training and evaluation data
The following code removes, from both the training and test sets, every sample that carries a label appearing only once in the training set:
from collections import Counter
from itertools import chain

from datasets import load_dataset

# Find Single Appearance Labels
def find_single_appearance_labels(y):
    """Find labels that appear only once in the dataset."""
    all_labels = list(chain.from_iterable(y))
    label_count = Counter(all_labels)
    single_appearance_labels = [label for label, count in label_count.items() if count == 1]
    return single_appearance_labels

# Remove Single Appearance Labels from Dataset
def remove_single_appearance_labels(dataset, single_appearance_labels):
    """Remove samples with single-appearance labels from both train and test sets."""
    for split in ['train', 'test']:
        dataset[split] = dataset[split].filter(lambda x: all(label not in single_appearance_labels for label in x['topics']))
    return dataset

dataset = load_dataset("reuters21578", "ModApte")
# Find and Remove Single Appearance Labels
y_train = [item['topics'] for item in dataset['train']]
single_appearance_labels = find_single_appearance_labels(y_train)
print(f"Single appearance labels: {single_appearance_labels}")
>>> Single appearance labels: ['lin-oil', 'rye', 'red-bean', 'groundnut-oil', 'citruspulp', 'rape-meal', 'corn-oil', 'peseta', 'cotton-oil', 'ringgit', 'castorseed', 'castor-oil', 'lit', 'rupiah', 'skr', 'nkr', 'dkr', 'sun-meal', 'lin-meal', 'cruzado']
print("Removing samples with single-appearance labels...")
dataset = remove_single_appearance_labels(dataset, single_appearance_labels)
unique_labels = set(chain.from_iterable(dataset['train']["topics"]))
print(f"We have {len(unique_labels)} unique labels:\n{unique_labels}")
>>> We have 95 unique labels:
{'veg-oil', 'gold', 'platinum', 'ipi', 'acq', 'carcass', 'wool', 'coconut-oil', 'linseed', 'copper', 'soy-meal', 'jet', 'dlr', 'copra-cake', 'hog', 'rand', 'strategic-metal', 'can', 'tea', 'sorghum', 'livestock', 'barley', 'lumber', 'earn', 'wheat', 'trade', 'soy-oil', 'cocoa', 'inventories', 'income', 'rubber', 'tin', 'iron-steel', 'ship', 'rapeseed', 'wpi', 'sun-oil', 'pet-chem', 'palmkernel', 'nat-gas', 'gnp', 'l-cattle', 'propane', 'rice', 'lead', 'alum', 'instal-debt', 'saudriyal', 'cpu', 'jobs', 'meal-feed', 'oilseed', 'dmk', 'plywood', 'zinc', 'retail', 'dfl', 'cpi', 'crude', 'pork-belly', 'gas', 'money-fx', 'corn', 'tapioca', 'palladium', 'lei', 'cornglutenfeed', 'sunseed', 'potato', 'silver', 'sugar', 'grain', 'groundnut', 'naphtha', 'orange', 'soybean', 'coconut', 'stg', 'cotton', 'yen', 'rape-oil', 'palm-oil', 'oat', 'reserves', 'housing', 'interest', 'coffee', 'fuel', 'austdlr', 'money-supply', 'heat', 'fishmeal', 'bop', 'nickel', 'nzdlr'}
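The same filtering logic can be checked on a toy dataset without downloading anything. This is a minimal sketch on plain Python lists; the samples are hypothetical and the structure only mimics the `topics` field used above.

```python
from collections import Counter
from itertools import chain

# Toy stand-in for the dataset splits (hypothetical samples)
toy = {
    "train": [
        {"text": "a", "topics": ["earn"]},
        {"text": "b", "topics": ["earn", "rye"]},
        {"text": "c", "topics": ["acq"]},
        {"text": "d", "topics": ["acq", "earn"]},
    ],
    "test": [
        {"text": "e", "topics": ["rye"]},
        {"text": "f", "topics": ["acq"]},
    ],
}

# Count labels on the training split only
label_count = Counter(chain.from_iterable(s["topics"] for s in toy["train"]))
single = {label for label, n in label_count.items() if n == 1}

# Drop every sample, in both splits, that carries one of those labels
filtered = {
    split: [s for s in samples if not single.intersection(s["topics"])]
    for split, samples in toy.items()
}

print(single)                                  # {'rye'}
print([s["text"] for s in filtered["train"]])  # ['a', 'c', 'd']
print([s["text"] for s in filtered["test"]])   # ['f']
```

Note that whole samples are dropped, not just the rare label: sample "b" disappears even though it also carries the frequent label "earn".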
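Training procedure
- Exploratory data analysis of the Reuters-21578 dataset: this notebook explores the dataset with visualizations and summary statistics to show its structure, label distribution, and text characteristics.
- Reuters Scikit-learn baseline model: this notebook builds a text classification baseline for Reuters-21578 with Scikit-learn, covering data preprocessing, feature extraction, model training, and evaluation.
- Reuters Transformer model: this notebook covers the implementation details, training process, and performance metrics of a Transformer-based text classifier on Reuters-21578.
- Multi-label stratified sampling and hyperparameter search on the Reuters dataset: this notebook uses the Hugging Face Trainer API to explore multi-label iterative stratified splitting and hyperparameter search. The former aims to distribute an imbalanced dataset fairly across folds in k-fold cross-validation; the latter walks through a structured hyperparameter search to optimize model performance.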
Evaluation results
Transformer model evaluation results
Classification Report:
precision recall f1-score support
acq 0.97 0.93 0.95 719
alum 1.00 0.70 0.82 23
austdlr 0.00 0.00 0.00 0
barley 1.00 0.50 0.67 12
bop 0.79 0.50 0.61 30
can 0.00 0.00 0.00 0
carcass 0.67 0.67 0.67 18
cocoa 1.00 1.00 1.00 18
coconut 0.00 0.00 0.00 2
coconut-oil 0.00 0.00 0.00 2
coffee 0.86 0.89 0.87 27
copper 1.00 0.78 0.88 18
copra-cake 0.00 0.00 0.00 1
corn 0.84 0.87 0.86 55
cornglutenfeed 0.00 0.00 0.00 0
cotton 0.92 0.67 0.77 18
cpi 0.86 0.43 0.57 28
cpu 0.00 0.00 0.00 1
crude 0.87 0.93 0.90 189
dfl 0.00 0.00 0.00 1
dlr 0.72 0.64 0.67 44
dmk 0.00 0.00 0.00 4
earn 0.98 0.99 0.98 1087
fishmeal 0.00 0.00 0.00 0
fuel 0.00 0.00 0.00 10
gas 0.80 0.71 0.75 17
gnp 0.79 0.66 0.72 35
gold 0.95 0.67 0.78 30
grain 0.94 0.92 0.93 146
groundnut 0.00 0.00 0.00 4
heat 0.00 0.00 0.00 5
hog 1.00 0.33 0.50 6
housing 0.00 0.00 0.00 4
income 0.00 0.00 0.00 7
instal-debt 0.00 0.00 0.00 1
interest 0.89 0.67 0.77 131
inventories 0.00 0.00 0.00 0
ipi 1.00 0.58 0.74 12
iron-steel 0.90 0.64 0.75 14
jet 0.00 0.00 0.00 1
jobs 0.92 0.57 0.71 21
l-cattle 0.00 0.00 0.00 2
lead 0.00 0.00 0.00 14
lei 0.00 0.00 0.00 3
linseed 0.00 0.00 0.00 0
livestock 0.63 0.79 0.70 24
lumber 0.00 0.00 0.00 6
meal-feed 0.00 0.00 0.00 17
money-fx 0.78 0.81 0.80 177
money-supply 0.80 0.71 0.75 34
naphtha 0.00 0.00 0.00 4
nat-gas 0.82 0.60 0.69 30
nickel 0.00 0.00 0.00 1
nzdlr 0.00 0.00 0.00 2
oat 0.00 0.00 0.00 4
oilseed 0.64 0.61 0.63 44
orange 1.00 0.36 0.53 11
palladium 0.00 0.00 0.00 1
palm-oil 1.00 0.56 0.71 9
palmkernel 0.00 0.00 0.00 1
pet-chem 0.00 0.00 0.00 12
platinum 0.00 0.00 0.00 7
plywood 0.00 0.00 0.00 0
pork-belly 0.00 0.00 0.00 0
potato 0.00 0.00 0.00 3
propane 0.00 0.00 0.00 3
rand 0.00 0.00 0.00 1
rape-oil 0.00 0.00 0.00 1
rapeseed 0.00 0.00 0.00 8
reserves 0.83 0.56 0.67 18
retail 0.00 0.00 0.00 2
rice 1.00 0.57 0.72 23
rubber 0.82 0.75 0.78 12
saudriyal 0.00 0.00 0.00 0
ship 0.95 0.81 0.87 89
silver 1.00 0.12 0.22 8
sorghum 1.00 0.12 0.22 8
soy-meal 0.00 0.00 0.00 12
soy-oil 0.00 0.00 0.00 8
soybean 0.72 0.56 0.63 32
stg 0.00 0.00 0.00 0
strategic-metal 0.00 0.00 0.00 11
sugar 1.00 0.80 0.89 35
sun-oil 0.00 0.00 0.00 0
sunseed 0.00 0.00 0.00 5
tapioca 0.00 0.00 0.00 0
tea 0.00 0.00 0.00 3
tin 1.00 0.42 0.59 12
trade 0.78 0.79 0.79 116
veg-oil 0.91 0.59 0.71 34
wheat 0.83 0.83 0.83 69
wool 0.00 0.00 0.00 0
wpi 0.00 0.00 0.00 10
yen 0.57 0.29 0.38 14
zinc 1.00 0.69 0.82 13
micro avg 0.92 0.81 0.86 3694
macro avg 0.41 0.30 0.33 3694
weighted avg 0.87 0.81 0.84 3694
samples avg 0.81 0.80 0.80 3694
Scikit-learn baseline model evaluation results
Classification Report:
precision recall f1-score support
acq 0.98 0.87 0.92 719
alum 1.00 0.00 0.00 23
austdlr 1.00 1.00 1.00 0
barley 1.00 0.00 0.00 12
bop 1.00 0.30 0.46 30
can 1.00 1.00 1.00 0
carcass 1.00 0.06 0.11 18
cocoa 1.00 0.61 0.76 18
coconut 1.00 0.00 0.00 2
coconut-oil 1.00 0.00 0.00 2
coffee 0.94 0.59 0.73 27
copper 1.00 0.22 0.36 18
copra-cake 1.00 0.00 0.00 1
corn 0.97 0.51 0.67 55
cornglutenfeed 1.00 1.00 1.00 0
cotton 1.00 0.06 0.11 18
cpi 1.00 0.14 0.25 28
cpu 1.00 0.00 0.00 1
crude 0.94 0.69 0.80 189
dfl 1.00 0.00 0.00 1
dlr 0.86 0.43 0.58 44
dmk 1.00 0.00 0.00 4
earn 0.99 0.97 0.98 1087
fishmeal 1.00 1.00 1.00 0
fuel 1.00 0.00 0.00 10
gas 1.00 0.00 0.00 17
gnp 1.00 0.31 0.48 35
gold 0.83 0.17 0.28 30
grain 1.00 0.65 0.79 146
groundnut 1.00 0.00 0.00 4
heat 1.00 0.00 0.00 5
hog 1.00 0.00 0.00 6
housing 1.00 0.00 0.00 4
income 1.00 0.00 0.00 7
instal-debt 1.00 0.00 0.00 1
interest 0.88 0.40 0.55 131
inventories 1.00 1.00 1.00 0
ipi 1.00 0.00 0.00 12
iron-steel 1.00 0.00 0.00 14
jet 1.00 0.00 0.00 1
jobs 1.00 0.14 0.25 21
l-cattle 1.00 0.00 0.00 2
lead 1.00 0.00 0.00 14
lei 1.00 0.00 0.00 3
linseed 1.00 1.00 1.00 0
livestock 0.67 0.08 0.15 24
lumber 1.00 0.00 0.00 6
meal-feed 1.00 0.00 0.00 17
money-fx 0.80 0.50 0.62 177
money-supply 0.88 0.41 0.56 34
naphtha 1.00 0.00 0.00 4
nat-gas 1.00 0.27 0.42 30
nickel 1.00 0.00 0.00 1
nzdlr 1.00 0.00 0.00 2
oat 1.00 0.00 0.00 4
oilseed 0.62 0.11 0.19 44
orange 1.00 0.00 0.00 11
palladium 1.00 0.00 0.00 1
palm-oil 1.00 0.22 0.36 9
palmkernel 1.00 0.00 0.00 1
pet-chem 1.00 0.00 0.00 12
platinum 1.00 0.00 0.00 7
plywood 1.00 1.00 1.00 0
pork-belly 1.00 1.00 1.00 0
potato 1.00 0.00 0.00 3
propane 1.00 0.00 0.00 3
rand 1.00 0.00 0.00 1
rape-oil 1.00 0.00 0.00 1
rapeseed 1.00 0.00 0.00 8
reserves 1.00 0.00 0.00 18
retail 1.00 0.00 0.00 2
rice 1.00 0.00 0.00 23
rubber 1.00 0.17 0.29 12
saudriyal 1.00 1.00 1.00 0
ship 0.92 0.26 0.40 89
silver 1.00 0.00 0.00 8
sorghum 1.00 0.00 0.00 8
soy-meal 1.00 0.00 0.00 12
soy-oil 1.00 0.00 0.00 8
soybean 1.00 0.16 0.27 32
stg 1.00 1.00 1.00 0
strategic-metal 1.00 0.00 0.00 11
sugar 1.00 0.60 0.75 35
sun-oil 1.00 1.00 1.00 0
sunseed 1.00 0.00 0.00 5
tapioca 1.00 1.00 1.00 0
tea 1.00 0.00 0.00 3
tin 1.00 0.00 0.00 12
trade 0.92 0.61 0.74 116
veg-oil 1.00 0.12 0.21 34
wheat 0.97 0.55 0.70 69
wool 1.00 1.00 1.00 0
wpi 1.00 0.00 0.00 10
yen 1.00 0.00 0.00 14
zinc 1.00 0.00 0.00 13
micro avg 0.97 0.64 0.77 3694
macro avg 0.98 0.25 0.29 3694
weighted avg 0.96 0.64 0.70 3694
samples avg 0.98 0.74 0.75 3694
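Training hyperparameters
The following hyperparameters were used during training:
- Learning rate: 2e-05
- Training batch size: 32
- Evaluation batch size: 32
- Random seed: 42
- Optimizer: Adam (β1=0.9, β2=0.999, ε=1e-08)
- Learning rate scheduler type: linear
- Number of epochs: 20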
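As a rough illustration of what a linear scheduler does with these settings, the sketch below computes the learning rate at a given optimizer step. The total of 6000 steps is taken from the training results table (20 epochs of 300 steps each); the warmup length of 0 is an assumption, since the source does not state it, and `linear_lr` is a hypothetical helper, not a library function.

```python
def linear_lr(step, base_lr=2e-5, total_steps=6000, warmup_steps=0):
    """Linear warmup (if any) followed by linear decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


# Learning rate at the start, midpoint, and end of training
for step in (0, 3000, 6000):
    print(step, linear_lr(step))
```

Under these assumptions the rate starts at 2e-05, is halved by the end of epoch 10, and reaches zero at step 6000.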
Training results
Training loss | Epoch | Step | Validation loss | F1 | ROC AUC | Accuracy |
---|---|---|---|---|---|---|
0.1801 | 1.0 | 300 | 0.0439 | 0.3896 | 0.6210 | 0.3566 |
0.0345 | 2.0 | 600 | 0.0287 | 0.6289 | 0.7318 | 0.5954 |
0.0243 | 3.0 | 900 | 0.0219 | 0.6721 | 0.7579 | 0.6084 |
0.0178 | 4.0 | 1200 | 0.0177 | 0.7505 | 0.8128 | 0.6908 |
0.014 | 5.0 | 1500 | 0.0151 | 0.7905 | 0.8376 | 0.7278 |
0.0115 | 6.0 | 1800 | 0.0135 | 0.8132 | 0.8589 | 0.7555 |
0.0096 | 7.0 | 2100 | 0.0124 | 0.8291 | 0.8727 | 0.7725 |
0.0082 | 8.0 | 2400 | 0.0124 | 0.8335 | 0.8757 | 0.7822 |
0.0071 | 9.0 | 2700 | 0.0119 | 0.8392 | 0.8847 | 0.7883 |
0.0064 | 10.0 | 3000 | 0.0123 | 0.8339 | 0.8810 | 0.7828 |
0.0058 | 11.0 | 3300 | 0.0114 | 0.8538 | 0.8999 | 0.8047 |
0.0053 | 12.0 | 3600 | 0.0113 | 0.8525 | 0.8967 | 0.8044 |
0.0048 | 13.0 | 3900 | 0.0115 | 0.8520 | 0.8982 | 0.8029 |
0.0045 | 14.0 | 4200 | 0.0111 | 0.8566 | 0.8962 | 0.8104 |
0.0042 | 15.0 | 4500 | 0.0110 | 0.8610 | 0.9060 | 0.8165 |
0.0039 | 16.0 | 4800 | 0.0112 | 0.8583 | 0.9021 | 0.8138 |
0.0037 | 17.0 | 5100 | 0.0110 | 0.8620 | 0.9055 | 0.8196 |
0.0035 | 18.0 | 5400 | 0.0110 | 0.8629 | 0.9063 | 0.8196 |
0.0035 | 19.0 | 5700 | 0.0111 | 0.8624 | 0.9062 | 0.8180 |
0.0034 | 20.0 | 6000 | 0.0111 | 0.8626 | 0.9055 | 0.8177 |
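The accuracy column sits below the F1 column at every epoch. One plausible reading, which is an assumption because the source does not define the metric, is that accuracy here means exact-match (subset) accuracy, where a sample counts as correct only if its whole predicted label set matches. A minimal sketch on made-up predictions:

```python
def subset_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label set matches the true set exactly."""
    exact = sum(1 for t, p in zip(y_true, y_pred) if set(t) == set(p))
    return exact / len(y_true)


# Hypothetical label sets: the first sample is only partially correct
y_true = [{"crude", "nat-gas"}, {"earn"}, {"acq"}, {"grain", "wheat"}]
y_pred = [{"crude"}, {"earn"}, {"acq"}, {"grain", "wheat"}]
print(subset_accuracy(y_true, y_pred))  # 0.75
```

The first sample gets no credit under this metric even though one of its two labels is right, whereas F1 would reward the partial match.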
Framework versions
- Transformers 4.33.0.dev0
- Pytorch 2.0.1+cu117
- Datasets 2.14.3
- Tokenizers 0.13.3
📄 License
This model is released under the Apache 2.0 license.








