🚀 SegVol
SegVolは、体積医用画像セグメンテーションのための汎用的で対話型のモデルです。このモデルは、点、ボックス、およびテキストプロンプトを受け取り、体積セグメンテーションを出力します。90kの未ラベル付きコンピュータ断層撮影(CT)ボリュームと6kのラベル付きCTで訓練することで、この基盤モデルは200以上の解剖学的カテゴリのセグメンテーションをサポートします。
論文 と コード が公開されています。
キーワード: 3D医用SAM、体積画像セグメンテーション

🚀 クイックスタート
必要条件
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
pytorch v1.11.0 (またはそれ以上のバージョン)が必要です。以下のコマンドを使用して主要な要件をインストールしてください。
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
テストスクリプト
from transformers import AutoModel, AutoTokenizer
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.zoom_transform(ct_npy, gt_npy)
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
data_item['image'].unsqueeze(0).to(device), data_item['label'].unsqueeze(0).to(device), data_item['zoom_out_image'].unsqueeze(0).to(device), data_item['zoom_out_label'].unsqueeze(0).to(device)
cls_idx = 0
text_prompt = [categories[cls_idx]]
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
print('prompt done')
logits_mask = model.forward_test(image=data_item['image'],
zoomed_image=data_item['zoom_out_image'],
bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
text_prompt=text_prompt,
use_zoom=False
)
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)
save_path='./Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
start_coord=data_item['foreground_start_coord'],
end_coord=data_item['foreground_end_coord'])
print('done')
訓練スクリプト
from transformers import AutoModel, AutoTokenizer
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.train_transform(ct_npy, gt_npy)
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)
loss_step_avg = 0
for cls_idx in range(len(categories)):
organs_cls = categories[cls_idx]
labels_cls = gt3D[:, cls_idx]
print(image.shape, organs_cls, labels_cls.shape)
loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
loss_step_avg += loss.item()
loss.backward()
loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')
model.save_pretrained('./ckpt')
📄 ライセンス
このプロジェクトはMITライセンスの下で公開されています。