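🚀 Fine-tuned DistilBERT for Reuters multi-label text classification
This model is distilbert-base-cased fine-tuned on the Reuters-21578 (reuters21578) multi-label dataset for multi-label news topic classification. On the test set it reaches a micro-average F1 of 0.86 and outperforms a Scikit-learn baseline on every F1 average reported in this card.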
🚀 Quick start
Inference example

```python
from transformers import pipeline

# return_all_scores=True returns a score for every label
# (newer transformers releases warn that it is deprecated in favor of top_k=None)
pipe = pipeline("text-classification", model="lxyuan/distilbert-finetuned-reuters21578-multilabel", return_all_scores=True)

# dataset["test"]["text"][2]
news_article = (
    "JAPAN TO REVISE LONG-TERM ENERGY DEMAND DOWNWARDS The Ministry of International Trade and "
    "Industry (MITI) will revise its long-term energy supply/demand "
    "outlook by August to meet a forecast downtrend in Japanese "
    "energy demand, ministry officials said. "
    "MITI is expected to lower the projection for primary energy "
    "supplies in the year 2000 to 550 mln kilolitres (kl) from 600 "
    "mln, they said. "
    "The decision follows the emergence of structural changes in "
    "Japanese industry following the rise in the value of the yen "
    "and a decline in domestic electric power demand. "
    "MITI is planning to work out a revised energy supply/demand "
    "outlook through deliberations of committee meetings of the "
    "Agency of Natural Resources and Energy, the officials said. "
    "They said MITI will also review the breakdown of energy "
    "supply sources, including oil, nuclear, coal and natural gas. "
    "Nuclear energy provided the bulk of Japan's electric power "
    "in the fiscal year ended March 31, supplying an estimated 27 "
    "pct on a kilowatt/hour basis, followed by oil (23 pct) and "
    "liquefied natural gas (21 pct), they noted. "
    "REUTER"
)

# dataset["test"]["topics"][2]
target_topics = ['crude', 'nat-gas']

fn_kwargs = {"padding": "max_length", "truncation": True, "max_length": 512}
# Multi-label: apply a sigmoid to each label's logit independently
output = pipe(news_article, function_to_apply="sigmoid", **fn_kwargs)

# Keep every label whose score reaches the 0.5 threshold
for item in output[0]:
    if item["score"] >= 0.5:
        print(item["label"], item["score"])

# crude 0.7355073690414429
# nat-gas 0.8600426316261292
```
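The `function_to_apply="sigmoid"` argument is what makes the prediction multi-label: each label's logit is squashed independently, so any number of labels can clear the 0.5 threshold, whereas a softmax would force the scores to sum to 1. The sketch below reproduces that post-processing step in plain Python. The logits are hypothetical values chosen for illustration, not model outputs, and `select_labels` is a helper name introduced here.

```python
import math
from typing import Dict, List, Tuple


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)  # avoids overflow in exp(-x) for very negative x
    return e / (1.0 + e)


def select_labels(logits: Dict[str, float], threshold: float = 0.5) -> List[Tuple[str, float]]:
    """Score each label independently and keep those at or above the threshold,
    highest score first."""
    scores = {label: sigmoid(z) for label, z in logits.items()}
    kept = [(label, s) for label, s in scores.items() if s >= threshold]
    return sorted(kept, key=lambda pair: pair[1], reverse=True)


# Hypothetical logits for illustration only (not produced by the model).
logits = {"crude": 1.0, "nat-gas": 1.8, "earn": -4.0, "ship": -0.5}
for label, score in select_labels(logits):
    print(label, round(score, 3))
# nat-gas 0.858
# crude 0.731
```

Because sigmoid(0) = 0.5, a 0.5 threshold on the score is equivalent to keeping every label with a non-negative logit.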
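✨ Key features
- Multi-label classification: assigns every applicable topic to a news article, so a single article can receive several labels.
- Performance: outperforms the Scikit-learn baseline on every reported F1 average, for example weighted-average F1 (0.84 vs 0.70) and samples-average F1 (0.80 vs 0.75).
- Class imbalance: handles minority classes somewhat better than the baseline (macro-average F1 0.33 vs 0.29), though there is still room for improvement.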
📚 Documentation
Model origin
This model is a fork of distilbert-finetuned-reuters21578-multilabel; the only addition is an ONNX version of the model in the /onnx directory.
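Motivation
Fine-tuning on the Reuters-21578 multi-label dataset is a worthwhile exercise: the dataset is a common choice for take-home tests in job interviews. Its complexity suits testing multi-label classification skills within limited time, and it is close enough to real applications to simulate real-world challenges. Working through it helps candidates prepare for interviews and builds skills in data preprocessing, feature extraction, and model evaluation.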
Overall summary and comparison
Metric | Scikit-learn baseline | Transformer model |
---|---|---|
Micro-average F1 | 0.77 | 0.86 |
Macro-average F1 | 0.29 | 0.33 |
Weighted-average F1 | 0.70 | 0.84 |
Samples-average F1 | 0.75 | 0.80 |
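The four rows average per-label results in different ways: micro-averaging pools true positives, false positives, and false negatives over all labels; macro-averaging takes the unweighted mean of per-label F1; weighted averaging weights each label's F1 by its support; samples-averaging computes F1 per document and then averages. The pure-Python sketch below mirrors those definitions on hypothetical three-label data (multi-hot 0/1 rows) and shows how one poorly predicted rare label pulls the macro average well below the other three. The helper names are introduced here for illustration.

```python
from typing import Dict, List, Sequence


def f1_from_counts(tp: int, fp: int, fn: int, zero_division: float = 0.0) -> float:
    """F1 = 2TP / (2TP + FP + FN); returns zero_division when the ratio is 0/0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else zero_division


def averaged_f1(y_true: Sequence[Sequence[int]], y_pred: Sequence[Sequence[int]]) -> Dict[str, float]:
    n_labels = len(y_true[0])
    tp, fp, fn = [0] * n_labels, [0] * n_labels, [0] * n_labels
    per_sample: List[float] = []
    for t_row, p_row in zip(y_true, y_pred):
        row_tp = row_fp = row_fn = 0
        for j, (t, p) in enumerate(zip(t_row, p_row)):
            hit, false_pos, miss = t * p, (1 - t) * p, t * (1 - p)
            tp[j] += hit
            fp[j] += false_pos
            fn[j] += miss
            row_tp, row_fp, row_fn = row_tp + hit, row_fp + false_pos, row_fn + miss
        per_sample.append(f1_from_counts(row_tp, row_fp, row_fn))  # F1 of one document
    per_label = [f1_from_counts(tp[j], fp[j], fn[j]) for j in range(n_labels)]
    support = [tp[j] + fn[j] for j in range(n_labels)]  # true occurrences per label
    total_support = sum(support)
    return {
        "micro": f1_from_counts(sum(tp), sum(fp), sum(fn)),
        "macro": sum(per_label) / n_labels,
        "weighted": sum(f * s for f, s in zip(per_label, support)) / total_support if total_support else 0.0,
        "samples": sum(per_sample) / len(per_sample),
    }


# Hypothetical data; columns are ["earn", "crude", "nat-gas"], and "nat-gas" is the rare label.
y_true = [[1, 0, 0], [0, 1, 1], [1, 0, 0], [0, 1, 0]]
y_pred = [[1, 0, 0], [0, 1, 0], [1, 0, 0], [0, 1, 1]]
print({name: round(value, 3) for name, value in averaged_f1(y_true, y_pred).items()})
# {'micro': 0.8, 'macro': 0.667, 'weighted': 0.8, 'samples': 0.833}
```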
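Precision and recall: both models favor high precision over recall (micro-average precision 0.97 for the baseline and 0.92 for the Transformer, against recall of 0.64 and 0.81). For a customer-facing news classifier, precision matters more than recall: a false positive (a wrong topic shown to the customer) is more damaging and harder to explain than a false negative (a missed topic).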
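In a thresholded multi-label classifier, the decision threshold (0.5 in the inference example) is the main lever for trading recall for precision. The sketch below uses hypothetical scores and gold labels for a single topic to show the effect; `precision_recall_at` is a name introduced here. Raising the threshold can only lower recall; precision usually rises, as it does here, but that is not guaranteed in general.

```python
from typing import Sequence, Tuple


def precision_recall_at(scores: Sequence[float], truth: Sequence[int], threshold: float) -> Tuple[float, float]:
    """Precision and recall when every score >= threshold counts as a positive prediction."""
    tp = sum(1 for s, t in zip(scores, truth) if s >= threshold and t)
    fp = sum(1 for s, t in zip(scores, truth) if s >= threshold and not t)
    fn = sum(1 for s, t in zip(scores, truth) if s < threshold and t)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall


# Hypothetical sigmoid scores for one topic over eight articles, and whether each article really has it.
scores = [0.95, 0.85, 0.65, 0.55, 0.45, 0.35, 0.25, 0.10]
truth = [1, 1, 0, 1, 1, 0, 1, 0]
for threshold in (0.3, 0.5, 0.7):
    p, r = precision_recall_at(scores, truth, threshold)
    print(f"threshold={threshold}: precision={p:.2f} recall={r:.2f}")
# threshold=0.3: precision=0.67 recall=0.80
# threshold=0.5: precision=0.75 recall=0.60
# threshold=0.7: precision=1.00 recall=0.40
```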
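Class imbalance: both models perform poorly on minority classes, as their low macro-average F1 scores show, although the Transformer improves it slightly (0.33 vs 0.29).
Zero-support labels: in both models' reports some labels have no samples in the test set (support 0). These labels distort the averaged metrics; more broadly, the many labels with zero or near-zero F1 suggest either that the models cannot predict minority classes well or that the dataset has too few examples of them.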
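For a label with no test samples and no predictions, precision, recall, and F1 are all 0/0, so the report must substitute a value. Scikit-learn's `classification_report` exposes this choice as its `zero_division` parameter. In the evaluation results, the baseline report scores every zero-support label as 1.00 and the Transformer report as 0.00, which is consistent with different `zero_division` settings (an inference from those rows; the card does not state the setting). The hypothetical four-label example below, in plain Python, shows how much that choice alone moves a macro average; `macro_f1` is a name introduced here.

```python
from typing import List, Tuple

Counts = Tuple[int, int, int]  # (true positives, false positives, false negatives) for one label


def f1_from_counts(tp: int, fp: int, fn: int, zero_division: float) -> float:
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else zero_division


def macro_f1(per_label: List[Counts], zero_division: float, skip_zero_support: bool = False) -> float:
    """Unweighted mean of per-label F1; optionally drop labels with no true samples (tp + fn == 0)."""
    kept = [c for c in per_label if not (skip_zero_support and c[0] + c[2] == 0)]
    return sum(f1_from_counts(*c, zero_division) for c in kept) / len(kept)


# Hypothetical counts: two labels with test samples, two with support 0 and no predictions.
per_label = [(5, 1, 1), (3, 0, 2), (0, 0, 0), (0, 0, 0)]
print(round(macro_f1(per_label, zero_division=0.0), 3))  # 0.396
print(round(macro_f1(per_label, zero_division=1.0), 3))  # 0.896
print(round(macro_f1(per_label, zero_division=0.0, skip_zero_support=True), 3))  # 0.792
```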
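Overall performance: the Transformer beats the Scikit-learn baseline on weighted-average F1 (0.84 vs 0.70) and samples-average F1 (0.80 vs 0.75), indicating better overall performance and better handling of label imbalance.
Conclusion: both models achieve high precision, but the Transformer outperforms the Scikit-learn baseline on all four F1 averages considered, strikes a better balance between precision and recall, and does somewhat better on minority classes. So although it shares the baseline's weaknesses, its incremental gains could matter in production.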
Training and evaluation data
The following code finds labels that appear only once in the training set and removes every sample carrying one of them from both the training and test sets:

```python
from collections import Counter
from itertools import chain

from datasets import load_dataset


# Find Single Appearance Labels
def find_single_appearance_labels(y):
    """Find labels that appear only once in the dataset."""
    all_labels = list(chain.from_iterable(y))
    label_count = Counter(all_labels)
    single_appearance_labels = [label for label, count in label_count.items() if count == 1]
    return single_appearance_labels


# Remove Single Appearance Labels from Dataset
def remove_single_appearance_labels(dataset, single_appearance_labels):
    """Remove samples with single-appearance labels from both train and test sets."""
    for split in ['train', 'test']:
        dataset[split] = dataset[split].filter(
            lambda x: all(label not in single_appearance_labels for label in x['topics'])
        )
    return dataset


dataset = load_dataset("reuters21578", "ModApte")

# Find and Remove Single Appearance Labels
y_train = [item['topics'] for item in dataset['train']]
single_appearance_labels = find_single_appearance_labels(y_train)
print(f"Single appearance labels: {single_appearance_labels}")
# Single appearance labels: ['lin-oil', 'rye', 'red-bean', 'groundnut-oil', 'citruspulp', 'rape-meal', 'corn-oil', 'peseta', 'cotton-oil', 'ringgit', 'castorseed', 'castor-oil', 'lit', 'rupiah', 'skr', 'nkr', 'dkr', 'sun-meal', 'lin-meal', 'cruzado']

print("Removing samples with single-appearance labels...")
dataset = remove_single_appearance_labels(dataset, single_appearance_labels)

unique_labels = set(chain.from_iterable(dataset['train']["topics"]))
print(f"We have {len(unique_labels)} unique labels:\n{unique_labels}")
# We have 95 unique labels:
# {'veg-oil', 'gold', 'platinum', 'ipi', 'acq', 'carcass', 'wool', 'coconut-oil', 'linseed', 'copper', 'soy-meal', 'jet', 'dlr', 'copra-cake', 'hog', 'rand', 'strategic-metal', 'can', 'tea', 'sorghum', 'livestock', 'barley', 'lumber', 'earn', 'wheat', 'trade', 'soy-oil', 'cocoa', 'inventories', 'income', 'rubber', 'tin', 'iron-steel', 'ship', 'rapeseed', 'wpi', 'sun-oil', 'pet-chem', 'palmkernel', 'nat-gas', 'gnp', 'l-cattle', 'propane', 'rice', 'lead', 'alum', 'instal-debt', 'saudriyal', 'cpu', 'jobs', 'meal-feed', 'oilseed', 'dmk', 'plywood', 'zinc', 'retail', 'dfl', 'cpi', 'crude', 'pork-belly', 'gas', 'money-fx', 'corn', 'tapioca', 'palladium', 'lei', 'cornglutenfeed', 'sunseed', 'potato', 'silver', 'sugar', 'grain', 'groundnut', 'naphtha', 'orange', 'soybean', 'coconut', 'stg', 'cotton', 'yen', 'rape-oil', 'palm-oil', 'oat', 'reserves', 'housing', 'interest', 'coffee', 'fuel', 'austdlr', 'money-supply', 'heat', 'fishmeal', 'bop', 'nickel', 'nzdlr'}
```
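Training a multi-label classification head needs each `topics` list turned into a fixed-length 0/1 target vector over the label set. The card does not show that step, so the following is a sketch of one common approach, not the author's code; `build_label_index` and `multi_hot` are hypothetical helper names. The labels are sorted so the index is reproducible, because Python's iteration order over a set of strings can change between runs (hash randomization). The targets are floats because Hugging Face sequence-classification models configured with `problem_type="multi_label_classification"` train with `BCEWithLogitsLoss`, which expects float targets.

```python
from typing import Dict, Iterable, List


def build_label_index(unique_labels: Iterable[str]) -> Dict[str, int]:
    """Map each label to a column index; sorting makes the mapping reproducible."""
    return {label: i for i, label in enumerate(sorted(unique_labels))}


def multi_hot(topics: Iterable[str], label2id: Dict[str, int]) -> List[float]:
    """Encode one article's topics as a 0.0/1.0 vector (float targets for a BCE-style loss)."""
    vector = [0.0] * len(label2id)
    for topic in topics:
        vector[label2id[topic]] = 1.0
    return vector


# Toy label set; with the real data this would be the 95-label `unique_labels` set.
label2id = build_label_index({"nat-gas", "earn", "crude"})
print(label2id)                                   # {'crude': 0, 'earn': 1, 'nat-gas': 2}
print(multi_hot(["crude", "nat-gas"], label2id))  # [1.0, 0.0, 1.0]
```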
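Training procedure
- Exploratory data analysis of Reuters-21578: this notebook explores the Reuters-21578 dataset with visualizations and statistical summaries, giving insight into the dataset's structure, label distribution, and text characteristics.
- Reuters Scikit-learn baseline: this notebook builds a Scikit-learn text-classification baseline for Reuters-21578, covering data preprocessing, feature extraction, model training, and evaluation.
- Reuters Transformer model: this notebook covers the implementation details, training process, and performance metrics of Transformer-based text classification on Reuters-21578.
- Multi-label stratified sampling and hyperparameter search on Reuters: this notebook uses the Hugging Face Trainer API to explore two techniques: iterative multi-label stratified splitting, which distributes an imbalanced dataset fairly across k-fold cross-validation splits, and a structured hyperparameter search to optimize model performance.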
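Iterative stratification addresses the fact that a plain random k-fold split can leave rare labels out of some folds entirely. The card does not include the notebook's code, so the following is a simplified pure-Python sketch of the idea from Sechidis, Tsoumakas and Vlahavas (2011), not the notebook's implementation: repeatedly take the label with the fewest unassigned samples and send each of those samples to the fold that still needs that label most, breaking ties by the fold that needs the most samples overall, then at random. `iterative_stratification` is a hypothetical helper name, and the toy data is made up.

```python
import random
from typing import List, Sequence, Set


def iterative_stratification(labels: Sequence[Set[str]], k: int, seed: int = 42) -> List[int]:
    """Return a fold index in [0, k) for every sample (simplified iterative stratification)."""
    rng = random.Random(seed)
    n = len(labels)
    all_labels = set().union(*labels)
    # How many more samples each fold wants, overall and per label (equal-sized folds).
    want_total = [n / k] * k
    want_label = {lab: [sum(lab in s for s in labels) / k] * k for lab in all_labels}
    fold_of = [-1] * n
    remaining = set(range(n))

    while remaining:
        # Count how many unassigned samples still carry each label.
        counts = {}
        for i in remaining:
            for lab in labels[i]:
                counts[lab] = counts.get(lab, 0) + 1
        if not counts:
            # Only unlabelled samples are left: give each to the fold wanting the most samples.
            for i in sorted(remaining):
                f = max(range(k), key=lambda j: (want_total[j], rng.random()))
                fold_of[i] = f
                want_total[f] -= 1
            break
        # Rarest remaining label first (ties broken by name for determinism).
        rarest = min(counts, key=lambda lab: (counts[lab], lab))
        for i in sorted(idx for idx in remaining if rarest in labels[idx]):
            # Fold that most wants this label; ties -> most wanted samples overall -> random.
            f = max(range(k), key=lambda j: (want_label[rarest][j], want_total[j], rng.random()))
            fold_of[i] = f
            want_total[f] -= 1
            for lab in labels[i]:
                want_label[lab][f] -= 1
            remaining.discard(i)
    return fold_of


# Hypothetical topic sets for ten articles.
toy = [{"a"}, {"a", "b"}, {"b"}, {"a"}, {"b"}, {"a", "b"}, set(), {"c"}, {"c"}, {"a"}]
folds = iterative_stratification(toy, k=2)
for f in range(2):
    members = [toy[i] for i, g in enumerate(folds) if g == f]
    per_label = {lab: sum(lab in s for s in members) for lab in ("a", "b", "c")}
    print(f"fold {f}: {len(members)} samples, label counts {per_label}")
```

On this toy data both folds get five articles, and the rare label "c" lands once in each fold instead of possibly twice in one.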
Evaluation results
Transformer model
Classification report:
precision recall f1-score support
acq 0.97 0.93 0.95 719
alum 1.00 0.70 0.82 23
austdlr 0.00 0.00 0.00 0
barley 1.00 0.50 0.67 12
bop 0.79 0.50 0.61 30
can 0.00 0.00 0.00 0
carcass 0.67 0.67 0.67 18
cocoa 1.00 1.00 1.00 18
coconut 0.00 0.00 0.00 2
coconut-oil 0.00 0.00 0.00 2
coffee 0.86 0.89 0.87 27
copper 1.00 0.78 0.88 18
copra-cake 0.00 0.00 0.00 1
corn 0.84 0.87 0.86 55
cornglutenfeed 0.00 0.00 0.00 0
cotton 0.92 0.67 0.77 18
cpi 0.86 0.43 0.57 28
cpu 0.00 0.00 0.00 1
crude 0.87 0.93 0.90 189
dfl 0.00 0.00 0.00 1
dlr 0.72 0.64 0.67 44
dmk 0.00 0.00 0.00 4
earn 0.98 0.99 0.98 1087
fishmeal 0.00 0.00 0.00 0
fuel 0.00 0.00 0.00 10
gas 0.80 0.71 0.75 17
gnp 0.79 0.66 0.72 35
gold 0.95 0.67 0.78 30
grain 0.94 0.92 0.93 146
groundnut 0.00 0.00 0.00 4
heat 0.00 0.00 0.00 5
hog 1.00 0.33 0.50 6
housing 0.00 0.00 0.00 4
income 0.00 0.00 0.00 7
instal-debt 0.00 0.00 0.00 1
interest 0.89 0.67 0.77 131
inventories 0.00 0.00 0.00 0
ipi 1.00 0.58 0.74 12
iron-steel 0.90 0.64 0.75 14
jet 0.00 0.00 0.00 1
jobs 0.92 0.57 0.71 21
l-cattle 0.00 0.00 0.00 2
lead 0.00 0.00 0.00 14
lei 0.00 0.00 0.00 3
linseed 0.00 0.00 0.00 0
livestock 0.63 0.79 0.70 24
lumber 0.00 0.00 0.00 6
meal-feed 0.00 0.00 0.00 17
money-fx 0.78 0.81 0.80 177
money-supply 0.80 0.71 0.75 34
naphtha 0.00 0.00 0.00 4
nat-gas 0.82 0.60 0.69 30
nickel 0.00 0.00 0.00 1
nzdlr 0.00 0.00 0.00 2
oat 0.00 0.00 0.00 4
oilseed 0.64 0.61 0.63 44
orange 1.00 0.36 0.53 11
palladium 0.00 0.00 0.00 1
palm-oil 1.00 0.56 0.71 9
palmkernel 0.00 0.00 0.00 1
pet-chem 0.00 0.00 0.00 12
platinum 0.00 0.00 0.00 7
plywood 0.00 0.00 0.00 0
pork-belly 0.00 0.00 0.00 0
potato 0.00 0.00 0.00 3
propane 0.00 0.00 0.00 3
rand 0.00 0.00 0.00 1
rape-oil 0.00 0.00 0.00 1
rapeseed 0.00 0.00 0.00 8
reserves 0.83 0.56 0.67 18
retail 0.00 0.00 0.00 2
rice 1.00 0.57 0.72 23
rubber 0.82 0.75 0.78 12
saudriyal 0.00 0.00 0.00 0
ship 0.95 0.81 0.87 89
silver 1.00 0.12 0.22 8
sorghum 1.00 0.12 0.22 8
soy-meal 0.00 0.00 0.00 12
soy-oil 0.00 0.00 0.00 8
soybean 0.72 0.56 0.63 32
stg 0.00 0.00 0.00 0
strategic-metal 0.00 0.00 0.00 11
sugar 1.00 0.80 0.89 35
sun-oil 0.00 0.00 0.00 0
sunseed 0.00 0.00 0.00 5
tapioca 0.00 0.00 0.00 0
tea 0.00 0.00 0.00 3
tin 1.00 0.42 0.59 12
trade 0.78 0.79 0.79 116
veg-oil 0.91 0.59 0.71 34
wheat 0.83 0.83 0.83 69
wool 0.00 0.00 0.00 0
wpi 0.00 0.00 0.00 10
yen 0.57 0.29 0.38 14
zinc 1.00 0.69 0.82 13
micro avg 0.92 0.81 0.86 3694
macro avg 0.41 0.30 0.33 3694
weighted avg 0.87 0.81 0.84 3694
samples avg 0.81 0.80 0.80 3694
Scikit-learn baseline model
Classification report:
precision recall f1-score support
acq 0.98 0.87 0.92 719
alum 1.00 0.00 0.00 23
austdlr 1.00 1.00 1.00 0
barley 1.00 0.00 0.00 12
bop 1.00 0.30 0.46 30
can 1.00 1.00 1.00 0
carcass 1.00 0.06 0.11 18
cocoa 1.00 0.61 0.76 18
coconut 1.00 0.00 0.00 2
coconut-oil 1.00 0.00 0.00 2
coffee 0.94 0.59 0.73 27
copper 1.00 0.22 0.36 18
copra-cake 1.00 0.00 0.00 1
corn 0.97 0.51 0.67 55
cornglutenfeed 1.00 1.00 1.00 0
cotton 1.00 0.06 0.11 18
cpi 1.00 0.14 0.25 28
cpu 1.00 0.00 0.00 1
crude 0.94 0.69 0.80 189
dfl 1.00 0.00 0.00 1
dlr 0.86 0.43 0.58 44
dmk 1.00 0.00 0.00 4
earn 0.99 0.97 0.98 1087
fishmeal 1.00 1.00 1.00 0
fuel 1.00 0.00 0.00 10
gas 1.00 0.00 0.00 17
gnp 1.00 0.31 0.48 35
gold 0.83 0.17 0.28 30
grain 1.00 0.65 0.79 146
groundnut 1.00 0.00 0.00 4
heat 1.00 0.00 0.00 5
hog 1.00 0.00 0.00 6
housing 1.00 0.00 0.00 4
income 1.00 0.00 0.00 7
instal-debt 1.00 0.00 0.00 1
interest 0.88 0.40 0.55 131
inventories 1.00 1.00 1.00 0
ipi 1.00 0.00 0.00 12
iron-steel 1.00 0.00 0.00 14
jet 1.00 0.00 0.00 1
jobs 1.00 0.14 0.25 21
l-cattle 1.00 0.00 0.00 2
lead 1.00 0.00 0.00 14
lei 1.00 0.00 0.00 3
linseed 1.00 1.00 1.00 0
livestock 0.67 0.08 0.15 24
lumber 1.00 0.00 0.00 6
meal-feed 1.00 0.00 0.00 17
money-fx 0.80 0.50 0.62 177
money-supply 0.88 0.41 0.56 34
naphtha 1.00 0.00 0.00 4
nat-gas 1.00 0.27 0.42 30
nickel 1.00 0.00 0.00 1
nzdlr 1.00 0.00 0.00 2
oat 1.00 0.00 0.00 4
oilseed 0.62 0.11 0.19 44
orange 1.00 0.00 0.00 11
palladium 1.00 0.00 0.00 1
palm-oil 1.00 0.22 0.36 9
palmkernel 1.00 0.00 0.00 1
pet-chem 1.00 0.00 0.00 12
platinum 1.00 0.00 0.00 7
plywood 1.00 1.00 1.00 0
pork-belly 1.00 1.00 1.00 0
potato 1.00 0.00 0.00 3
propane 1.00 0.00 0.00 3
rand 1.00 0.00 0.00 1
rape-oil 1.00 0.00 0.00 1
rapeseed 1.00 0.00 0.00 8
reserves 1.00 0.00 0.00 18
retail 1.00 0.00 0.00 2
rice 1.00 0.00 0.00 23
rubber 1.00 0.17 0.29 12
saudriyal 1.00 1.00 1.00 0
ship 0.92 0.26 0.40 89
silver 1.00 0.00 0.00 8
sorghum 1.00 0.00 0.00 8
soy-meal 1.00 0.00 0.00 12
soy-oil 1.00 0.00 0.00 8
soybean 1.00 0.16 0.27 32
stg 1.00 1.00 1.00 0
strategic-metal 1.00 0.00 0.00 11
sugar 1.00 0.60 0.75 35
sun-oil 1.00 1.00 1.00 0
sunseed 1.00 0.00 0.00 5
tapioca 1.00 1.00 1.00 0
tea 1.00 0.00 0.00 3
tin 1.00 0.00 0.00 12
trade 0.92 0.61 0.74 116
veg-oil 1.00 0.12 0.21 34
wheat 0.97 0.55 0.70 69
wool 1.00 1.00 1.00 0
wpi 1.00 0.00 0.00 10
yen 1.00 0.00 0.00 14
zinc 1.00 0.00 0.00 13
micro avg 0.97 0.64 0.77 3694
macro avg 0.98 0.25 0.29 3694
weighted avg 0.96 0.64 0.70 3694
samples avg 0.98 0.74 0.75 3694
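Training hyperparameters
The following hyperparameters were used during training:
- Learning rate: 2e-05
- Training batch size: 32
- Evaluation batch size: 32
- Random seed: 42
- Optimizer: Adam (β1=0.9, β2=0.999, ε=1e-08)
- Learning-rate scheduler: linear
- Number of epochs: 20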
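With a linear scheduler the learning rate decays from 2e-05 toward 0 over the whole run. The training results table shows 300 steps per epoch, so 20 epochs give 6,000 optimizer steps. The sketch below computes the rate at a few steps, assuming no warmup (none is listed) and using the same formula as the linear-with-warmup schedule in Hugging Face Transformers (`get_linear_schedule_with_warmup`); `linear_lr` is a name introduced here.

```python
def linear_lr(step: int, base_lr: float = 2e-05, total_steps: int = 6000, warmup_steps: int = 0) -> float:
    """Learning rate after `step` optimizer steps: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


steps_per_epoch = 300  # from the training results table (step 300 at epoch 1.0)
total_steps = 20 * steps_per_epoch
for step in (0, 1500, 3000, 6000):
    print(step, f"{linear_lr(step, total_steps=total_steps):.1e}")
# 0 2.0e-05
# 1500 1.5e-05
# 3000 1.0e-05
# 6000 0.0e+00
```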
Training results
Training loss | Epoch | Step | Validation loss | F1 | ROC AUC | Accuracy |
---|---|---|---|---|---|---|
0.1801 | 1.0 | 300 | 0.0439 | 0.3896 | 0.6210 | 0.3566 |
0.0345 | 2.0 | 600 | 0.0287 | 0.6289 | 0.7318 | 0.5954 |
0.0243 | 3.0 | 900 | 0.0219 | 0.6721 | 0.7579 | 0.6084 |
0.0178 | 4.0 | 1200 | 0.0177 | 0.7505 | 0.8128 | 0.6908 |
0.014 | 5.0 | 1500 | 0.0151 | 0.7905 | 0.8376 | 0.7278 |
0.0115 | 6.0 | 1800 | 0.0135 | 0.8132 | 0.8589 | 0.7555 |
0.0096 | 7.0 | 2100 | 0.0124 | 0.8291 | 0.8727 | 0.7725 |
0.0082 | 8.0 | 2400 | 0.0124 | 0.8335 | 0.8757 | 0.7822 |
0.0071 | 9.0 | 2700 | 0.0119 | 0.8392 | 0.8847 | 0.7883 |
0.0064 | 10.0 | 3000 | 0.0123 | 0.8339 | 0.8810 | 0.7828 |
0.0058 | 11.0 | 3300 | 0.0114 | 0.8538 | 0.8999 | 0.8047 |
0.0053 | 12.0 | 3600 | 0.0113 | 0.8525 | 0.8967 | 0.8044 |
0.0048 | 13.0 | 3900 | 0.0115 | 0.8520 | 0.8982 | 0.8029 |
0.0045 | 14.0 | 4200 | 0.0111 | 0.8566 | 0.8962 | 0.8104 |
0.0042 | 15.0 | 4500 | 0.0110 | 0.8610 | 0.9060 | 0.8165 |
0.0039 | 16.0 | 4800 | 0.0112 | 0.8583 | 0.9021 | 0.8138 |
0.0037 | 17.0 | 5100 | 0.0110 | 0.8620 | 0.9055 | 0.8196 |
0.0035 | 18.0 | 5400 | 0.0110 | 0.8629 | 0.9063 | 0.8196 |
0.0035 | 19.0 | 5700 | 0.0111 | 0.8624 | 0.9062 | 0.8180 |
0.0034 | 20.0 | 6000 | 0.0111 | 0.8626 | 0.9055 | 0.8177 |
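The table shows that the "best" checkpoint depends on the criterion: validation loss reaches its minimum of 0.0110 at epochs 15, 17, and 18, while F1 peaks at 0.8629 at epoch 18. The card does not say which checkpoint was published. The snippet below simply reads both criteria off the table.

```python
# (epoch, validation loss, F1) copied from the training results table above.
history = [
    (1, 0.0439, 0.3896), (2, 0.0287, 0.6289), (3, 0.0219, 0.6721), (4, 0.0177, 0.7505),
    (5, 0.0151, 0.7905), (6, 0.0135, 0.8132), (7, 0.0124, 0.8291), (8, 0.0124, 0.8335),
    (9, 0.0119, 0.8392), (10, 0.0123, 0.8339), (11, 0.0114, 0.8538), (12, 0.0113, 0.8525),
    (13, 0.0115, 0.8520), (14, 0.0111, 0.8566), (15, 0.0110, 0.8610), (16, 0.0112, 0.8583),
    (17, 0.0110, 0.8620), (18, 0.0110, 0.8629), (19, 0.0111, 0.8624), (20, 0.0111, 0.8626),
]

best_by_f1 = max(history, key=lambda row: row[2])
min_loss = min(row[1] for row in history)
epochs_at_min_loss = [epoch for epoch, loss, _ in history if loss == min_loss]
print("best F1:", best_by_f1[2], "at epoch", best_by_f1[0])  # best F1: 0.8629 at epoch 18
print("lowest validation loss:", min_loss, "at epochs", epochs_at_min_loss)
# lowest validation loss: 0.011 at epochs [15, 17, 18]
```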
Framework versions
- Transformers 4.33.0.dev0
- Pytorch 2.0.1+cu117
- Datasets 2.14.3
- Tokenizers 0.13.3
📄 License
This model is released under the Apache 2.0 license.








