🚀 胸部X光肺炎分類視覺變換器模型
本模型用於胸部X光圖像的肺炎分類,基於預訓練模型微調而來,在評估集上表現出色,能高效準確地識別肺炎與正常的胸部X光影像,為醫療診斷提供有力支持。
🚀 快速開始
本模型是 google/vit-base-patch16-224-in21k 在胸部X光分類數據集上的微調版本。
它在評估集上取得了以下結果:
💻 使用示例
基礎用法
from transformers import pipeline
classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
classifier("https://d2jx2rerrg6sh3.cloudfront.net/image-handler/ts/20200618040600/ri/650/picture/2020/6/shutterstock_786937069.jpg")
>>>
[{'score': 0.990334689617157, 'label': 'PNEUMONIA'},
{'score': 0.009665317833423615, 'label': 'NORMAL'}]
📚 詳細文檔
訓練過程
筆記本鏈接:點擊查看
訓練超參數
訓練過程中使用了以下超參數:
- 學習率:5e-05
- 訓練批次大小:16
- 評估批次大小:16
- 隨機種子:42
- 梯度累積步數:4
- 總訓練批次大小:64
- 優化器:Adam(β1 = 0.9,β2 = 0.999,ε = 1e-08)
- 學習率調度器類型:線性
- 學習率調度器預熱比例:0.1
- 訓練輪數:15
from transformers import EarlyStoppingCallback
training_args = TrainingArguments(
output_dir="vit-xray-pneumonia-classification",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
per_device_eval_batch_size=16,
num_train_epochs=15,
save_total_limit=2,
warmup_ratio=0.1,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
fp16=True,
push_to_hub=True,
report_to="tensorboard"
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_ds,
eval_dataset=val_ds,
tokenizer=processor,
compute_metrics=compute_metrics,
callbacks=[early_stopping],
)
訓練結果
訓練損失 |
輪數 |
步數 |
驗證損失 |
準確率 |
0.5152 |
0.99 |
63 |
0.2507 |
0.9245 |
0.2334 |
1.99 |
127 |
0.1766 |
0.9382 |
0.1647 |
3.0 |
191 |
0.1218 |
0.9588 |
0.144 |
4.0 |
255 |
0.1222 |
0.9502 |
0.1348 |
4.99 |
318 |
0.1293 |
0.9571 |
0.1276 |
5.99 |
382 |
0.1000 |
0.9665 |
0.1175 |
7.0 |
446 |
0.1177 |
0.9502 |
0.109 |
8.0 |
510 |
0.1079 |
0.9665 |
0.0914 |
8.99 |
573 |
0.0804 |
0.9717 |
0.0872 |
9.99 |
637 |
0.0800 |
0.9717 |
0.0804 |
11.0 |
701 |
0.0862 |
0.9682 |
0.0935 |
12.0 |
765 |
0.0883 |
0.9657 |
0.0686 |
12.99 |
828 |
0.0868 |
0.9742 |
框架版本
- Transformers 4.30.2
- Pytorch 1.9.0+cu102
- Datasets 2.12.0
- Tokenizers 0.13.3
📄 許可證
本模型採用 Apache-2.0 許可證。
屬性 |
詳情 |
模型類型 |
圖像分類 |
訓練數據 |
chest-xray-classification、keremberke/chest-xray-classification |
評估指標 |
準確率 |
基礎模型 |
google/vit-base-patch16-224-in21k |