🚀 乳腺影像分類集成模型
這是一個基於篩查乳腺鉬靶影像來預測乳腺癌和乳腺密度的集成模型。該模型利用先進的技術和多分辨率策略,為乳腺疾病的診斷提供了有力支持。
🚀 快速開始
此模型是一個基於篩查乳腺鉬靶影像來預測乳腺癌和乳腺密度的集成模型。模型使用了 3 個基礎卷積神經網絡(CNN)(tf_efficientnetv2_s
骨幹網絡),並對每個提供的圖像(即 CC 和 MLO 視圖)進行推理。集成中的每個網絡使用不同的分辨率:2048 x 1024、1920 x 1280 和 1536 x 1536。最終輸出在提供的視圖和神經網絡之間進行平均。該模型也可以對單視圖(圖像)進行推理,不過性能會有所下降。
✨ 主要特性
- 多分辨率集成:模型使用 3 個基於
tf_efficientnetv2_s
骨幹網絡的 CNN,每個網絡採用不同的分辨率(2048 x 1024、1920 x 1280 和 1536 x 1536)進行推理,最後將結果平均,提高預測準確性。
- 多數據集訓練:先在 CBIS - DDSM 數據集上預訓練,再在 RSNA 篩查乳腺鉬靶乳腺癌檢測挑戰 的數據上進一步訓練,增強模型泛化能力。
- 支持單視圖和批量推理:既可以對單視圖圖像進行推理,也支持批量推理,滿足不同的使用場景。
📦 安裝指南
文檔未提及具體安裝步驟,故跳過此章節。
💻 使用示例
基礎用法
import cv2
import torch
from transformers import AutoModel
def crop_mammo(img, model, device):
img_shape = torch.tensor([img.shape[:2]]).to(device)
x = model.preprocess(img)
x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
coords = model(x, img_shape)
coords = coords[0].cpu().numpy()
x, y, w, h = coords
return img[y: y + h, x: x + w]
device = "cuda:0"
crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)
model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)
cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)
with torch.inference_mode():
output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
高級用法
單獨訪問每個神經網絡
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
out = model.net0(input_dict)
批量推理
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]
cc_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in cc_images]
mlo_images = [cv2.imread(_, cv2.IMREAD_GRAYSCALE) for _ in mlo_images]
cc_images = [crop_mammo(_, crop_model, device) for _ in cc_images]
mlo_images = [crop_mammo(_, crop_model, device) for _ in mlo_images]
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
output = model(input_dict, device=device)
📚 詳細文檔
模型訓練
- 預訓練:先在 CBIS - DDSM 數據集上預訓練,該數據集包含膠片乳腺鉬靶研究及良性和惡性腫塊與鈣化的 ROI 註釋。
- 進一步訓練:在 RSNA 篩查乳腺鉬靶乳腺癌檢測挑戰 的數據上進一步訓練,數據按 80%/10%/10% 劃分為訓練集、驗證集和測試集,評估在 10% 的測試集上進行,此過程重複 3 次以更好評估模型性能,提供的權重來自第一次數據劃分。
- 訓練技巧:訓練過程中使用了指數移動平均,提高了模型性能。
模型輸出
output
是一個字典,包含 cancer
和 density
兩個鍵。output['cancer']
是形狀為 (N, 1) 的張量,output['density']
是形狀為 (N, 4) 的張量。
- 若要獲取預測的密度類別,可使用
output['density'].argmax(1)
。若只提供單個研究,則 N = 1。
數據預處理
- 模型在
forward
函數內對數據進行預處理。若單獨訪問每個神經網絡(如 model.net{i}
),則需在 forward
函數外進行預處理。
圖像格式轉換
若將 DICOM 圖像轉換為 8 位 PNG/JPEG 圖像,需對像素值應用查找表,可使用 pydicom.pixels.apply_voi_lut
。若安裝了 pydicom
,可使用 model.load_image_from_dicom
直接加載 DICOM 圖像。
🔧 技術細節
模型架構
模型是一個集成模型,使用 3 個基於 tf_efficientnetv2_s
骨幹網絡的 CNN,每個網絡使用不同分辨率進行推理,最後將結果平均。
評估指標
主要評估指標是受試者工作特徵曲線下面積(AUC/AUROC),3 次數據劃分的平均 AUC 及標準差如下:
Split 1: 0.9464
Split 2: 0.9467
Split 3: 0.9422
Mean (std.): 0.9451 (0.002)
靈敏度和特異度
模型在不同靈敏度下的特異度如下(3 次劃分的平均值):
Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
📄 許可證
本模型使用 apache - 2.0
許可證。
⚠️ 重要提示
模型是使用裁剪後的圖像進行訓練的,因此建議在推理前對圖像進行裁剪。裁剪模型可從 此處 獲取。
💡 使用建議
若單獨訪問每個神經網絡(如 model.net{i}
),需在 forward
函數外進行預處理。若要進行批量推理,需構建每個乳腺的字典並將字典列表傳遞給模型。