🚀 胸部X光肺炎分類ビジョントランスフォーマーモデル
このモデルは胸部のX光画像の肺炎分類に使用され、事前学習済みモデルを微調整して作成されました。評価セットで優れた性能を発揮し、肺炎と正常な胸部のX光画像を効率的かつ正確に識別し、医療診断に強力な支援を提供します。
🚀 クイックスタート
このモデルは google/vit-base-patch16-224-in21k を胸部X光分類データセットで微調整したバージョンです。
評価セットでは以下の結果を得ています:
💻 使用例
基本的な使用法
from transformers import pipeline
classifier = pipeline(model="lxyuan/vit-xray-pneumonia-classification")
classifier("https://d2jx2rerrg6sh3.cloudfront.net/image-handler/ts/20200618040600/ri/650/picture/2020/6/shutterstock_786937069.jpg")
>>>
[{'score': 0.990334689617157, 'label': 'PNEUMONIA'},
{'score': 0.009665317833423615, 'label': 'NORMAL'}]
📚 ドキュメント
学習過程
ノートブックのリンク:クリックして表示
学習ハイパーパラメータ
学習過程では以下のハイパーパラメータが使用されました:
- 学習率:5e-05
- 学習バッチサイズ:16
- 評価バッチサイズ:16
- 乱数シード:42
- 勾配累積ステップ数:4
- 総学習バッチサイズ:64
- オプティマイザ:Adam(β1 = 0.9,β2 = 0.999,ε = 1e-08)
- 学習率スケジューラの種類:線形
- 学習率スケジューラのウォームアップ比率:0.1
- 学習エポック数:15
from transformers import EarlyStoppingCallback
training_args = TrainingArguments(
output_dir="vit-xray-pneumonia-classification",
remove_unused_columns=False,
evaluation_strategy="epoch",
save_strategy="epoch",
logging_strategy="epoch",
learning_rate=5e-5,
per_device_train_batch_size=16,
gradient_accumulation_steps=4,
per_device_eval_batch_size=16,
num_train_epochs=15,
save_total_limit=2,
warmup_ratio=0.1,
load_best_model_at_end=True,
metric_for_best_model="eval_loss",
greater_is_better=False,
fp16=True,
push_to_hub=True,
report_to="tensorboard"
)
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)
trainer = Trainer(
model=model,
args=training_args,
data_collator=data_collator,
train_dataset=train_ds,
eval_dataset=val_ds,
tokenizer=processor,
compute_metrics=compute_metrics,
callbacks=[early_stopping],
)
学習結果
学習損失 |
エポック数 |
ステップ数 |
検証損失 |
正解率 |
0.5152 |
0.99 |
63 |
0.2507 |
0.9245 |
0.2334 |
1.99 |
127 |
0.1766 |
0.9382 |
0.1647 |
3.0 |
191 |
0.1218 |
0.9588 |
0.144 |
4.0 |
255 |
0.1222 |
0.9502 |
0.1348 |
4.99 |
318 |
0.1293 |
0.9571 |
0.1276 |
5.99 |
382 |
0.1000 |
0.9665 |
0.1175 |
7.0 |
446 |
0.1177 |
0.9502 |
0.109 |
8.0 |
510 |
0.1079 |
0.9665 |
0.0914 |
8.99 |
573 |
0.0804 |
0.9717 |
0.0872 |
9.99 |
637 |
0.0800 |
0.9717 |
0.0804 |
11.0 |
701 |
0.0862 |
0.9682 |
0.0935 |
12.0 |
765 |
0.0883 |
0.9657 |
0.0686 |
12.99 |
828 |
0.0868 |
0.9742 |
フレームワークのバージョン
- Transformers 4.30.2
- Pytorch 1.9.0+cu102
- Datasets 2.12.0
- Tokenizers 0.13.3
📄 ライセンス
このモデルは Apache-2.0 ライセンスを採用しています。
属性 |
詳細 |
モデルタイプ |
画像分類 |
学習データ |
chest-xray-classification、keremberke/chest-xray-classification |
評価指標 |
正解率 |
ベースモデル |
google/vit-base-patch16-224-in21k |