Chest X-ray Segmentation and Classification Model
This model performs both segmentation and classification on chest radiographs (X-rays). It segments the right lung, left lung, and heart, and predicts the X-ray view as well as patient age and sex.
✨ Features
- Dual Functionality: Performs both segmentation and classification on chest radiographs.
- Accurate Segmentation: Segments the right lung, left lung, and heart (validation Dice similarity coefficients of 0.943 to 0.957).
- Comprehensive Classification: Predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.
- Diverse Datasets: Trained on a combination of the CheXpert (small version) and NIH Chest X-ray datasets.
📦 Installation
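To use this model, you need to have the `transformers` library installed. The usage examples below also import PyTorch (`torch`) and OpenCV (`cv2`). You can install them using the following command:
```bash
pip install transformers torch opencv-python
```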
💻 Usage Examples
Basic Usage
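```python
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# trust_remote_code=True allows loading the custom model code shipped with the repository
model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)

img = cv2.imread(..., cv2.IMREAD_GRAYSCALE)  # replace ... with the path to your image
x = model.preprocess(img)  # model-specific preprocessing of the 2D NumPy image
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add batch and channel dims: (1, 1, H, W)
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))
```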
Advanced Usage
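You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray. You can also calculate the cardiothoracic ratio (CTR), the width of the heart divided by the width of the chest (here approximated by the combined width of both lungs), using the following function. It expects a 2D NumPy label map, e.g. `out["mask"].argmax(1)[0].cpu().numpy()` for the first image in the batch: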
```python
import numpy as np

def calculate_ctr(mask):
    """Cardiothoracic ratio from a 2D label map (1 = right lung, 2 = left lung, 3 = heart).

    Horizontal extent of the heart divided by the horizontal extent of both lungs combined.
    """
    lungs = np.isin(mask, (1, 2))
    heart = mask == 3
    # np.where returns (row, column) indices; only the columns (x) matter for widths
    _, lung_x = np.where(lungs)
    _, heart_x = np.where(heart)
    lung_range = lung_x.max() - lung_x.min()
    heart_range = heart_x.max() - heart_x.min()
    return heart_range / lung_range
```
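Cropping to the lungs can be sketched in the same style. This is a minimal sketch, assuming `mask` is a 2D NumPy label map using the labels listed under Output Format (1 = right lung, 2 = left lung) at the same resolution as `img`; the `crop_to_lungs` name and its `margin` parameter are hypothetical. If `model.preprocess` changes the image size, resize the mask back first, for example with `cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST)`, or crop the preprocessed image instead.

```python
import numpy as np


def crop_to_lungs(img, mask, margin=0):
    """Crop `img` to the bounding box of the lung labels in `mask`.

    Hypothetical helper. Assumes `mask` is a 2D label map with the same
    height/width as `img` (1 = right lung, 2 = left lung). `margin` pads
    the box by that many pixels, clipped to the image border.
    Returns the crop and the box as (y0, y1, x0, x1) slice bounds.
    """
    lungs = np.isin(mask, (1, 2))
    if not lungs.any():
        raise ValueError("mask contains no lung pixels")
    rows, cols = np.where(lungs)
    y0 = max(int(rows.min()) - margin, 0)
    y1 = min(int(rows.max()) + margin + 1, mask.shape[0])
    x0 = max(int(cols.min()) - margin, 0)
    x1 = min(int(cols.max()) + margin + 1, mask.shape[1])
    return img[y0:y1, x0:x1], (y0, y1, x0, x1)
```

Nearest-neighbour interpolation is the safe choice when resizing a label map, since it never invents label values between classes.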
If you have `pydicom` installed, you can also load a DICOM image directly:
```python
img = model.load_image_from_dicom(path_to_dicom)
```
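`model.load_image_from_dicom` is the model's own helper, and its exact processing is not documented here. If you need a DICOM loader outside the model, a minimal sketch with pydicom might look like the following. `dicom_pixels_to_uint8` and `load_dicom_grayscale` are hypothetical names, and the min-max scaling is an assumption (it ignores windowing and VOI LUTs); only the `MONOCHROME1` inversion follows the DICOM convention that the minimum value displays as white. The 8-bit output matches the kind of grayscale image `cv2.imread` returns in Basic Usage.

```python
import numpy as np


def dicom_pixels_to_uint8(pixels, photometric="MONOCHROME2"):
    """Min-max scale raw DICOM pixel data to 8-bit grayscale (assumed scheme)."""
    pixels = pixels.astype(np.float32)
    lo, hi = pixels.min(), pixels.max()
    if hi > lo:
        pixels = (pixels - lo) / (hi - lo)
    else:
        pixels = np.zeros_like(pixels)  # constant image: nothing to scale
    if photometric == "MONOCHROME1":
        # MONOCHROME1 displays low values as white; invert to the usual convention
        pixels = 1.0 - pixels
    return np.round(pixels * 255).astype(np.uint8)


def load_dicom_grayscale(path):
    """Read a DICOM file with pydicom and convert it to 8-bit grayscale."""
    import pydicom  # imported lazily so the helper above works without pydicom

    ds = pydicom.dcmread(path)
    photometric = getattr(ds, "PhotometricInterpretation", "MONOCHROME2")
    return dicom_pixels_to_uint8(ds.pixel_array, photometric)
```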
Documentation
Model Architecture
The model uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification.
Training Data
The model was trained on a combination of the CheXpert (small version) and NIH Chest X-ray datasets. Segmentation masks were obtained from the CheXmask dataset (paper). The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training/20% validation.
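The card reports both image and patient counts (roughly 3.5 images per patient) but does not say how the 80/20 split was made. If you retrain on similar data, one common way to keep images of the same patient out of both sets is a patient-level split, sketched here as an assumption rather than the author's procedure; `patient_level_split` is a hypothetical helper. Because patients contribute different numbers of images, the image-level ratio will only be approximately 80/20.

```python
import numpy as np


def patient_level_split(patient_ids, val_fraction=0.2, seed=0):
    """Return (train_idx, val_idx) image indices with no patient in both sets."""
    patient_ids = np.asarray(patient_ids)
    unique = np.unique(patient_ids)
    rng = np.random.default_rng(seed)
    rng.shuffle(unique)  # shuffle patients, not images
    n_val = int(round(len(unique) * val_fraction))
    is_val = np.isin(patient_ids, unique[:n_val])
    return np.where(~is_val)[0], np.where(is_val)[0]
```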
Validation Performance
Segmentation (Dice similarity coefficient):
- Right lung: 0.957
- Left lung: 0.948
- Heart: 0.943

Age prediction:
- Mean absolute error: 5.25 years

Classification:
- View (AP, PA, lateral): 99.42% accuracy
- Sex (female): 0.999 AUC
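To evaluate on your own labelled data, the Dice and MAE metrics above can be computed with small NumPy helpers such as these. This is a sketch: the card does not say whether Dice was averaged per image or pooled over the validation set, and returning 1.0 when a label is absent from both masks is an assumption. For the AUC, `sklearn.metrics.roc_auc_score` is a standard option.

```python
import numpy as np


def dice(pred, target, label):
    """Dice similarity coefficient for one label: 2|A ∩ B| / (|A| + |B|)."""
    a = pred == label
    b = target == label
    denom = a.sum() + b.sum()
    if denom == 0:
        return 1.0  # assumption: label absent from both masks counts as perfect agreement
    return float(2.0 * np.logical_and(a, b).sum() / denom)


def mean_absolute_error(pred_age, true_age):
    """Mean absolute error, in the same units as the inputs (years here)."""
    return float(np.mean(np.abs(np.asarray(pred_age) - np.asarray(true_age))))
```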
Output Format
The output of the model is a dictionary with 4 keys:
- `mask`: 3 channels containing the segmentation masks. Take the argmax over the channel dimension to create a single image mask (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart.
- `age`: predicted age in years.
- `view`: 3 classes, one for each possible view. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
- `female`: binarize with `out["female"] >= 0.5`.
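A minimal sketch of turning this dictionary into readable per-image predictions, assuming the tensors have been moved to NumPy first (e.g. `out = {k: v.cpu().numpy() for k, v in out.items()}`) and that `age` and `female` have shape `(N,)` or `(N, 1)`; `decode_outputs` and `VIEW_NAMES` are hypothetical names. Note that an argmax over 3 channels can only produce 0, 1, or 2, so check `np.unique` on a real mask before relying on the 1/2/3 labels listed above.

```python
import numpy as np

VIEW_NAMES = ("AP", "PA", "lateral")  # index order documented above


def decode_outputs(out, threshold=0.5):
    """Convert one batch of model outputs (NumPy arrays) into a list of dicts."""
    masks = out["mask"].argmax(1)                      # (N, H, W) label maps
    views = out["view"].argmax(1)                      # (N,) view indices
    ages = np.asarray(out["age"]).reshape(-1)          # (N,) ages in years
    females = np.asarray(out["female"]).reshape(-1) >= threshold
    return [
        {
            "mask": masks[i],
            "age": float(ages[i]),
            "view": VIEW_NAMES[int(views[i])],
            "female": bool(females[i]),
        }
        for i in range(masks.shape[0])
    ]
```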
🔧 Technical Details
- Model Type: Image Segmentation and Classification
- Backbone: `tf_efficientnetv2_s`
- Decoder: U-Net
- Classifier: Linear Layer
- Training Data: CheXpert (small version), NIH Chest X-ray, CheXmask
- Dataset Split: 80% training / 20% validation
- View Classifier Training: Only on CheXpert images to avoid bias
License
This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use. The user assumes any and all responsibility regarding their own use of this model and its outputs.