Transformers
An ensemble model for predicting breast cancer and breast density based on screening mammography.
Quick Start
This is an ensemble model designed to predict breast cancer and breast density from screening mammography. It consists of 3 CNNs with a `tf_efficientnetv2_s` backbone. The model runs inference on each supplied image (CC and MLO views). Each neural network in the ensemble uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536. The final outputs are averaged across the provided views and neural nets. The model can also run inference on a single view (image), though performance will be lower.
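As a quick sketch of single-view inference (assuming the forward pass accepts a dictionary containing only one of the "cc"/"mlo" keys; the file name here is hypothetical, and cropping, shown under Usage Examples, is still recommended):

import cv2
import torch
from transformers import AutoModel

device = "cuda:0"
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True).eval().to(device)

# single CC view only; expect lower performance than with both CC and MLO views
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
with torch.inference_mode():
    output = model({"cc": cc_img}, device=device)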
⨠Features
- Pretraining: A hybrid classification-segmentation model was first pretrained on the Curated Breast Imaging Subset of the Digital Database for Screening Mammography (CBIS-DDSM).
- Further Training: The resultant model was further trained on data from the RSNA Screening Mammography Breast Cancer Detection challenge.
- Exponential Moving Averaging: Used during training to enhance performance.
- Cropping Recommendation: The model was trained with cropped images, so it's recommended to crop images before inference. A cropping model is available at https://huggingface.co/ianpan/mammo-crop.
- Evaluation Metric: The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC).
Usage Examples
Basic Usage
import cv2
import torch
from transformers import AutoModel

def crop_mammo(img, model, device):
    # predict the breast bounding box and crop the image to it
    img_shape = torch.tensor([img.shape[:2]]).to(device)
    x = model.preprocess(img)
    x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
    with torch.inference_mode():
        coords = model(x, img_shape)
    coords = coords[0].cpu().numpy()
    x, y, w, h = coords
    return img[y: y + h, x: x + w]

device = "cuda:0"

# cropping model and main ensemble model
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)

# load the CC and MLO views as grayscale images
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)

# crop each view before inference (recommended, since training used cropped images)
cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)

with torch.inference_mode():
    output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
Advanced Usage
Accessing Each Neural Net Separately
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
    out = model.net0(input_dict)
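To reproduce the ensemble average yourself, you can loop over the individual members. This is only a sketch: it assumes the three nets are exposed as model.net0, model.net1, and model.net2 (per the model.net{i} naming noted under Important Note) and that each member returns a dictionary with a cancer key like the ensemble output.

import torch

nets = [model.net0, model.net1, model.net2]  # assumed member names
cancer_preds = []
for net in nets:
    # each member needs its own preprocessing when called directly
    input_dict = net.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
    with torch.inference_mode():
        out = net(input_dict)
    cancer_preds.append(out["cancer"])

# average across ensemble members (the full model does this internally)
ensemble_cancer = torch.stack(cancer_preds).mean(0)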
Batch Inference
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]

cc_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in cc_images]
mlo_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in mlo_images]

cc_images = [crop_mammo(_, crop_model, device) for _ in cc_images]
mlo_images = [crop_mammo(_, crop_model, device) for _ in mlo_images]

# one dict per study, each containing the paired CC and MLO views
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
    output = model(input_dict, device=device)
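The batched call returns one row per study. A short sketch of iterating over the results, assuming the output shapes documented under Model Output below ((N, 1) for cancer and (N, 4) for density, with N = 4 here):

for i, (cancer_p, density_logits) in enumerate(zip(output["cancer"], output["density"])):
    print(f"study {i}: cancer score = {cancer_p.item():.4f}, "
          f"predicted density class = {density_logits.argmax().item()}")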
Documentation
Model Output
The model preprocesses the data within the `forward` function into the necessary format. `output` is a dictionary containing two keys: `cancer` and `density`. `output['cancer']` is a tensor of shape (N, 1) and `output['density']` is a tensor of shape (N, 4). If you want the predicted density class, take the argmax: `output['density'].argmax(1)`. If only a single study is provided, then N = 1.
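For example, with a single study (N = 1), a minimal sketch of extracting both predictions:

cancer_score = output['cancer'][0].item()               # scalar cancer score for the study
density_class = output['density'].argmax(1)[0].item()   # predicted density class index (0-3)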
Image Conversion
If you are converting images from DICOM to 8-bit PNG/JPEG, it is important to apply the lookup table to the pixel values, which can be done using `pydicom.pixels.apply_voi_lut`. If you have `pydicom` installed, you can also load a DICOM image directly using `img = model.load_image_from_dicom(path_to_dicom)`.
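A sketch of one possible DICOM-to-PNG conversion, assuming pydicom is installed; the file paths are hypothetical, the min-max rescaling is just one choice of 8-bit mapping, and the MONOCHROME1 handling may or may not be needed for your data:

import cv2
import numpy as np
import pydicom
from pydicom.pixels import apply_voi_lut

ds = pydicom.dcmread("mammo_cc.dcm")
arr = apply_voi_lut(ds.pixel_array, ds)  # apply the VOI lookup table to the raw pixel values

# rescale to 8-bit
arr = arr.astype(np.float32)
arr = (arr - arr.min()) / max(float(arr.max() - arr.min()), 1e-6) * 255.0
arr = arr.astype(np.uint8)

# MONOCHROME1 images are stored with inverted grayscale
if getattr(ds, "PhotometricInterpretation", "") == "MONOCHROME1":
    arr = 255 - arr

cv2.imwrite("mammo_cc.png", arr)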
Technical Details
Model Architecture
The ensemble consists of 3 CNNs with a `tf_efficientnetv2_s` backbone. Each net in the ensemble uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536. The final outputs are averaged across the provided views and the neural nets.
Training Data
- Pretraining: The model was first pretrained on the Curated Breast Imaging Subset of Digital Database for Screening Mammography (CBIS-DDSM).
- Further Training: Further trained on data from the RSNA Screening Mammography Breast Cancer Detection challenge. The data was split into 80%/10%/10% train/val/test. Evaluation was performed on the 10% holdout test split, and this procedure was repeated 3 separate times.
Evaluation Metrics
The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC). The average and standard deviation across the 3 splits are as follows:
- Split 1: 0.9464
- Split 2: 0.9467
- Split 3: 0.9422
- Mean (std.): 0.9451 (0.002)
Specificity (and the corresponding probability threshold) was also computed at several sensitivity operating points:
- Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
- Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
- Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
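As an illustration only, the reported thresholds could be applied to the model's cancer scores to pick an operating point; the value below is the mean threshold reported above for ~94.3% sensitivity, and any threshold should be validated on your own data:

THRESHOLD = 0.0127  # mean threshold at ~94.3% sensitivity (from the operating points above)
cancer_scores = output["cancer"].squeeze(1)  # scores from a previous inference call
flagged = cancer_scores >= THRESHOLD         # True = flag the study for review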
License
This project is licensed under the Apache-2.0 license.
Model Information
| Property | Details |
|----------|---------|
| Library Name | transformers |
| Tags | mammography, cancer, breast_cancer, radiology, breast_density |
| License | apache-2.0 |
| Base Model | timm/tf_efficientnetv2_s.in21k_ft_in1k |
| Pipeline Tag | image-classification |
Important Note
- The model was trained using cropped images, so it is recommended to crop the image prior to inference. A cropping model is provided at https://huggingface.co/ianpan/mammo-crop.
- If you access each neural net separately using `model.net{i}`, you must apply the preprocessing outside of the `forward` function.