🚀 乳腺影像分类集成模型
这是一个基于筛查乳腺钼靶影像来预测乳腺癌和乳腺密度的集成模型。该模型利用先进的技术和多分辨率策略,为乳腺疾病的诊断提供了有力支持。
🚀 快速开始
此模型是一个基于筛查乳腺钼靶影像来预测乳腺癌和乳腺密度的集成模型。模型使用了 3 个基础卷积神经网络(CNN)(tf_efficientnetv2_s
骨干网络),并对每个提供的图像(即 CC 和 MLO 视图)进行推理。集成中的每个网络使用不同的分辨率:2048 x 1024、1920 x 1280 和 1536 x 1536。最终输出在提供的视图和神经网络之间进行平均。该模型也可以对单视图(图像)进行推理,不过性能会有所下降。
✨ 主要特性
- 多分辨率集成:模型使用 3 个基于
tf_efficientnetv2_s
骨干网络的 CNN,每个网络采用不同的分辨率(2048 x 1024、1920 x 1280 和 1536 x 1536)进行推理,最后将结果平均,提高预测准确性。
- 多数据集训练:先在 CBIS - DDSM 数据集上预训练,再在 RSNA 筛查乳腺钼靶乳腺癌检测挑战 的数据上进一步训练,增强模型泛化能力。
- 支持单视图和批量推理:既可以对单视图图像进行推理,也支持批量推理,满足不同的使用场景。
📦 安装指南
文档未提及具体安装步骤,故跳过此章节。
💻 使用示例
基础用法
import cv2
import torch
from transformers import AutoModel
def crop_mammo(img, model, device):
img_shape = torch.tensor([img.shape[:2]]).to(device)
x = model.preprocess(img)
x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
coords = model(x, img_shape)
coords = coords[0].cpu().numpy()
x, y, w, h = coords
return img[y: y + h, x: x + w]
device = "cuda:0"
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)
cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)
with torch.inference_mode():
output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
高级用法
单独访问每个神经网络
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
out = model.net0(input_dict)
批量推理
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]
cc_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in cc_images]
mlo_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in mlo_images]
cc_images = [crop_mammo(_, crop_model, device) for _ in cc_images]
mlo_images = [crop_mammo(_, crop_model, device) for _ in mlo_images]
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
output = model(input_dict, device=device)
📚 详细文档
模型训练
- 预训练:先在 CBIS - DDSM 数据集上预训练,该数据集包含胶片乳腺钼靶研究及良性和恶性肿块与钙化的 ROI 注释。
- 进一步训练:在 RSNA 筛查乳腺钼靶乳腺癌检测挑战 的数据上进一步训练,数据按 80%/10%/10% 划分为训练集、验证集和测试集,评估在 10% 的测试集上进行,此过程重复 3 次以更好评估模型性能,提供的权重来自第一次数据划分。
- 训练技巧:训练过程中使用了指数移动平均,提高了模型性能。
模型输出
output
是一个字典,包含 cancer
和 density
两个键。output['cancer']
是形状为 (N, 1) 的张量,output['density']
是形状为 (N, 4) 的张量。
- 若要获取预测的密度类别,可使用
output['density'].argmax(1)
。若只提供单个研究,则 N = 1。
数据预处理
- 模型在
forward
函数内对数据进行预处理。若单独访问每个神经网络(如 model.net{i}
),则需在 forward
函数外进行预处理。
图像格式转换
若将 DICOM 图像转换为 8 位 PNG/JPEG 图像,需对像素值应用查找表,可使用 pydicom.pixels.apply_voi_lut
。若安装了 pydicom
,可使用 model.load_image_from_dicom
直接加载 DICOM 图像。
🔧 技术细节
模型架构
模型是一个集成模型,使用 3 个基于 tf_efficientnetv2_s
骨干网络的 CNN,每个网络使用不同分辨率进行推理,最后将结果平均。
评估指标
主要评估指标是受试者工作特征曲线下面积(AUC/AUROC),3 次数据划分的平均 AUC 及标准差如下:
Split 1: 0.9464
Split 2: 0.9467
Split 3: 0.9422
Mean (std.): 0.9451 (0.002)
灵敏度和特异度
模型在不同灵敏度下的特异度如下(3 次划分的平均值):
Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
📄 许可证
本模型使用 apache - 2.0
许可证。
⚠️ 重要提示
模型是使用裁剪后的图像进行训练的,因此建议在推理前对图像进行裁剪。裁剪模型可从 此处 获取。
💡 使用建议
若单独访问每个神经网络(如 model.net{i}
),需在 forward
函数外进行预处理。若要进行批量推理,需构建每个乳腺的字典并将字典列表传递给模型。