SegVol: Universal Interactive Model for Volumetric Medical Image Segmentation
SegVol is a universal, interactive model for volumetric medical image segmentation. It accepts point, box, and text prompts and outputs volumetric segmentation masks. Trained on 90k unlabeled Computed Tomography (CT) volumes and 6k labeled CT volumes, this foundation model supports the segmentation of more than 200 anatomical categories.
The paper and code have been released.
Keywords: 3D medical SAM, volumetric image segmentation

Quick Start
Requirements
conda create -n segvol_transformers python=3.8
conda activate segvol_transformers
PyTorch v1.11.0 or a later version is required. Install the key requirements using the following commands:
pip install 'monai[all]==0.9.0'
pip install einops==0.6.1
pip install transformers==4.18.0
pip install matplotlib
Usage Examples
Basic Usage - Test Script
from transformers import AutoModel, AutoTokenizer
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=True)
model.model.text_encoder.tokenizer = clip_tokenizer
model.eval()
model.to(device)
print('model load done')
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.zoom_transform(ct_npy, gt_npy)
# move the zoom-in and zoom-out views of the image and label to the target device
data_item['image'], data_item['label'], data_item['zoom_out_image'], data_item['zoom_out_label'] = \
    data_item['image'].unsqueeze(0).to(device), data_item['label'].unsqueeze(0).to(device), data_item['zoom_out_image'].unsqueeze(0).to(device), data_item['zoom_out_label'].unsqueeze(0).to(device)
cls_idx = 0
text_prompt = [categories[cls_idx]]
point_prompt, point_prompt_map = model.processor.point_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
bbox_prompt, bbox_prompt_map = model.processor.bbox_prompt_b(data_item['zoom_out_label'][0][cls_idx], device=device)
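# NOTE: both a point prompt and a box prompt are built above; the call below uses only the box prompt
# (a point-prompt variant is sketched after this script)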
print('prompt done')
logits_mask = model.forward_test(image=data_item['image'],
                                 zoomed_image=data_item['zoom_out_image'],
                                 bbox_prompt_group=[bbox_prompt, bbox_prompt_map],
                                 text_prompt=text_prompt,
                                 use_zoom=False
                                 )
dice = model.processor.dice_score(logits_mask[0][0], data_item['label'][0][cls_idx])
print(dice)
save_path='./Case_preds_00001.nii.gz'
model.processor.save_preds(ct_path, save_path, logits_mask[0][0],
                           start_coord=data_item['foreground_start_coord'],
                           end_coord=data_item['foreground_end_coord'])
print('done')
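The script above builds both a point prompt and a box prompt but only passes the box prompt to forward_test. A minimal sketch of the point-prompt variant, assuming forward_test also accepts a point_prompt_group argument (mirroring bbox_prompt_group) and that use_zoom=True enables the zoom-in-zoom-out inference:
# sketch: drive forward_test with the point prompt and zoom-in-zoom-out inference
# (point_prompt_group and use_zoom=True are assumptions based on the bbox_prompt_group interface above)
logits_mask_point = model.forward_test(image=data_item['image'],
                                       zoomed_image=data_item['zoom_out_image'],
                                       point_prompt_group=[point_prompt, point_prompt_map],
                                       text_prompt=text_prompt,
                                       use_zoom=True
                                       )
dice_point = model.processor.dice_score(logits_mask_point[0][0], data_item['label'][0][cls_idx])
print(dice_point)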
Advanced Usage - Training Script
from transformers import AutoModel, AutoTokenizer
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clip_tokenizer = AutoTokenizer.from_pretrained("yuxindu/segvol")
model = AutoModel.from_pretrained("yuxindu/segvol", trust_remote_code=True, test_mode=False)
model.model.text_encoder.tokenizer = clip_tokenizer
model.train()
model.to(device)
print('model load done')
ct_path = 'path/to/Case_image_00001_0000.nii.gz'
gt_path = 'path/to/Case_label_00001.nii.gz'
categories = ["liver", "kidney", "spleen", "pancreas"]
ct_npy, gt_npy = model.processor.preprocess_ct_gt(ct_path, gt_path, category=categories)
data_item = model.processor.train_transform(ct_npy, gt_npy)
image, gt3D = data_item["image"].unsqueeze(0).to(device), data_item["label"].unsqueeze(0).to(device)
loss_step_avg = 0
for cls_idx in range(len(categories)):
    organs_cls = categories[cls_idx]
    labels_cls = gt3D[:, cls_idx]
    print(image.shape, organs_cls, labels_cls.shape)
    # compute the supervised loss for this category's text prompt and label
    loss = model.forward_train(image, train_organs=organs_cls, train_labels=labels_cls)
    loss_step_avg += loss.item()
    loss.backward()
loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')
model.save_pretrained('./ckpt')
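The snippet above backpropagates a loss per category but never updates the model weights. A minimal sketch of wrapping it in an optimizer step (the AdamW optimizer and learning rate here are assumptions, not values from the paper):
from torch import optim

optimizer = optim.AdamW(model.parameters(), lr=1e-5)  # assumed optimizer and learning rate

optimizer.zero_grad()
loss_step_avg = 0
for cls_idx in range(len(categories)):
    loss = model.forward_train(image, train_organs=categories[cls_idx], train_labels=gt3D[:, cls_idx])
    loss_step_avg += loss.item()
    loss.backward()  # gradients accumulate over the categories of this volume
loss_step_avg /= len(categories)
print(f'AVG loss {loss_step_avg}')
optimizer.step()  # single update after accumulating gradients over all categories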
License
This project is licensed under the MIT License.