🚀 胸部X光图像分类器
本仓库包含一个经过微调的视觉变换器(Vision Transformer,ViT) 模型,用于对胸部X光图像进行分类,该模型使用了CheXpert数据集。此模型针对从胸部X光片分类各种肺部疾病的任务进行了微调,在区分不同病症方面取得了令人瞩目的准确率。
🚀 快速开始
本模型可用于胸部X光图像的疾病分类。你可以直接从Hugging Face的模型中心加载该模型,并输入胸部X光图像进行推理。若要在自己的数据集上微调该模型,可参考本仓库中的说明,将代码适配到你的数据集和训练配置中。
✨ 主要特性
- 基于视觉变换器(ViT) 架构,通过注意力机制高效提取特征,擅长处理基于图像的任务。
- 在CheXpert数据集上进行训练,该数据集包含标注的胸部X光图像,可用于检测肺炎、心脏扩大等疾病。
- 在训练过程中显著提高了准确率,能够很好地泛化到未见过的胸部X光图像。
📦 安装指南
文档未提及具体安装步骤,可参考Hugging Face的transformers
库安装说明进行安装。
💻 使用示例
基础用法
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
processor = AutoImageProcessor.from_pretrained("codewithdark/vit-chest-xray")
model = AutoModelForImageClassification.from_pretrained("codewithdark/vit-chest-xray")
label_columns = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']
image_path = "/content/images.jpeg"
image = Image.open(image_path)
if image.mode != 'RGB':
image = image.convert('RGB')
print("Image converted to RGB.")
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = torch.argmax(logits, dim=-1).item()
predicted_class_label = label_columns[predicted_class_idx]
print(f"Predicted Class Index: {predicted_class_idx}")
print(f"Predicted Class Label: {predicted_class_label}")
'''
Output :
Predicted Class Index: 4
Predicted Class Label: No Finding
'''
高级用法
文档未提及高级用法相关代码示例。
📚 详细文档
模型概述
经过微调的模型基于视觉变换器(ViT) 架构,该架构通过利用注意力机制进行高效特征提取,在处理基于图像的任务方面表现出色。该模型在CheXpert数据集上进行训练,该数据集包含标注的胸部X光图像,用于检测肺炎、心脏扩大等疾病。
性能
- 最终验证准确率:98.46%
- 最终训练损失:0.1069
- 最终验证损失:0.0980
该模型在训练过程中显著提高了准确率,证明了其能够很好地泛化到未见过的胸部X光图像。
数据集
用于微调模型的数据集是CheXpert数据集,其中包括来自不同患者的胸部X光图像,并带有多标签注释。数据包括每个患者胸部的正位和侧位视图,并标注了各种肺部疾病的标签。
有关数据集的更多详细信息,请访问 CheXpert官方网站。
训练细节
模型使用以下设置进行微调:
- 优化器:AdamW
- 学习率:3e - 5
- 批量大小:32
- 训练轮数:10
- 损失函数:带Logits的二元交叉熵
- 精度:混合精度(通过
torch.amp
)
🔧 技术细节
本模型基于视觉变换器(ViT)架构,利用注意力机制进行特征提取。在CheXpert数据集上进行训练,通过AdamW优化器和带Logits的二元交叉熵损失函数进行微调。训练过程中使用混合精度以提高效率。
📄 许可证
本模型采用MIT许可证。有关更多详细信息,请参阅 LICENSE。
致谢
编码愉快!🚀