🚀 胸部X光圖像分類器
本倉庫包含一個經過微調的視覺變換器(Vision Transformer,ViT) 模型,用於對胸部X光圖像進行分類,該模型使用了CheXpert數據集。此模型針對從胸部X光片分類各種肺部疾病的任務進行了微調,在區分不同病症方面取得了令人矚目的準確率。
🚀 快速開始
本模型可用於胸部X光圖像的疾病分類。你可以直接從Hugging Face的模型中心加載該模型,並輸入胸部X光圖像進行推理。若要在自己的數據集上微調該模型,可參考本倉庫中的說明,將代碼適配到你的數據集和訓練配置中。
✨ 主要特性
- 基於視覺變換器(ViT) 架構,通過注意力機制高效提取特徵,擅長處理基於圖像的任務。
- 在CheXpert數據集上進行訓練,該數據集包含標註的胸部X光圖像,可用於檢測肺炎、心臟擴大等疾病。
- 在訓練過程中顯著提高了準確率,能夠很好地泛化到未見過的胸部X光圖像。
📦 安裝指南
文檔未提及具體安裝步驟,可參考Hugging Face的transformers
庫安裝說明進行安裝。
💻 使用示例
基礎用法
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
processor = AutoImageProcessor.from_pretrained("codewithdark/vit-chest-xray")
model = AutoModelForImageClassification.from_pretrained("codewithdark/vit-chest-xray")
label_columns = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']
image_path = "/content/images.jpeg"
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert('RGB')
print("Image converted to RGB.")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = torch.argmax(logits, dim=-1).item()
predicted_class_label = label_columns[predicted_class_idx]
print(f"Predicted Class Index: {predicted_class_idx}")
print(f"Predicted Class Label: {predicted_class_label}")
'''
Output :
Predicted Class Index: 4
Predicted Class Label: No Finding
'''
高級用法
文檔未提及高級用法相關代碼示例。
📚 詳細文檔
模型概述
經過微調的模型基於視覺變換器(ViT) 架構,該架構通過利用注意力機制進行高效特徵提取,在處理基於圖像的任務方面表現出色。該模型在CheXpert數據集上進行訓練,該數據集包含標註的胸部X光圖像,用於檢測肺炎、心臟擴大等疾病。
性能
- 最終驗證準確率:98.46%
- 最終訓練損失:0.1069
- 最終驗證損失:0.0980
該模型在訓練過程中顯著提高了準確率,證明了其能夠很好地泛化到未見過的胸部X光圖像。
數據集
用於微調模型的數據集是CheXpert數據集,其中包括來自不同患者的胸部X光圖像,並帶有多標籤註釋。數據包括每個患者胸部的正位和側位視圖,並標註了各種肺部疾病的標籤。
有關數據集的更多詳細信息,請訪問 CheXpert官方網站。
訓練細節
模型使用以下設置進行微調:
- 優化器:AdamW
- 學習率:3e - 5
- 批量大小:32
- 訓練輪數:10
- 損失函數:帶Logits的二元交叉熵
- 精度:混合精度(通過
torch.amp
)
🔧 技術細節
本模型基於視覺變換器(ViT)架構,利用注意力機制進行特徵提取。在CheXpert數據集上進行訓練,通過AdamW優化器和帶Logits的二元交叉熵損失函數進行微調。訓練過程中使用混合精度以提高效率。
📄 許可證
本模型採用MIT許可證。有關更多詳細信息,請參閱 LICENSE。
致謝
編碼愉快!🚀