🚀 M3D:借助多模态大语言模型推动3D医学图像分析发展
M3D是首个全面致力于3D医学分析的多模态大语言模型系列工作,旨在解决3D医学图像分析中的复杂问题,为医学研究和临床应用提供强大的支持。
论文 | 数据 | 代码
✨ 主要特性
M3D系列工作涵盖了数据集、模型和评估基准三个关键部分,具体如下:
- M3D-Data:这是目前最大规模的开源3D医学数据集,包含120K图像-文本对和662K指令-响应对,为模型训练提供了丰富的数据资源。
- M3D-LaMed:基于M3D-CLIP预训练视觉编码器的多模态模型,具备图像-文本检索、报告生成、视觉问答、定位和分割等多种任务能力。
- M3D-Bench:最全面的自动评估基准,涵盖8个任务,可有效评估模型在不同任务上的性能。
⚠️ 重要提示
我们发现之前的GoodBaiBai88/M3D-LaMed-Llama-2-7B模型在分割任务中存在问题。目前已修复该问题,并将在未来几天内重新发布新模型。
🚀 快速开始
我们可以基于Hugging Face轻松使用我们的模型。
基础用法
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import simple_slice_viewer as ssv
import SimpleITK as sikt
device = torch.device('cuda')
dtype = torch.bfloat16
model_name_or_path = 'GoodBaiBai88/M3D-LaMed-Llama-2-7B'
proj_out_num = 256
image_path = "./Data/data/examples/example_01.npy"
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
torch_dtype=dtype,
device_map='auto',
trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(
model_name_or_path,
model_max_length=512,
padding_side="right",
use_fast=False,
trust_remote_code=True
)
model = model.to(device=device)
question = "What is liver in this image? Please output the segmentation mask."
image_tokens = "<im_patch>" * proj_out_num
input_txt = image_tokens + question
input_id = tokenizer(input_txt, return_tensors="pt")['input_ids'].to(device=device)
image_np = np.load(image_path)
image_pt = torch.from_numpy(image_np).unsqueeze(0).to(dtype=dtype, device=device)
generation, seg_logit = model.generate(image_pt, input_id, seg_enable=True, max_new_tokens=256, do_sample=True, top_p=0.9, temperature=1.0)
generated_texts = tokenizer.batch_decode(generation, skip_special_tokens=True)
seg_mask = (torch.sigmoid(seg_logit) > 0.5) * 1.0
print('question', question)
print('generated_texts', generated_texts[0])
image = sikt.GetImageFromArray(image_np)
ssv.display(image)
seg = sikt.GetImageFromArray(seg_mask.cpu().numpy()[0])
ssv.display(seg)
📄 许可证
本项目采用Apache-2.0许可证。