🚀 实例分割示例
本项目专注于图像分割领域,借助预训练模型和特定数据集,实现了实例分割任务的微调训练与推理,为相关领域的研究和应用提供了实用的范例和代码参考。
🚀 快速开始
内容导航
✨ 主要特性
- 基于 🤗 Trainer API 管理训练,支持分布式环境。
- 对 Mask2Former 模型在 ADE20K 数据集的子样本上进行微调。
- 提供了完整的训练和推理代码示例。
📦 安装指南
此部分未提供具体安装命令,暂不展示安装指南。
💻 使用示例
基础用法
PyTorch 版本与训练器
本模型基于脚本 run_instance_segmentation.py
。该脚本使用 🤗 Trainer API 自动管理训练,包括分布式环境。
这里,我们在 ADE20K 数据集的子样本上微调 Mask2Former 模型。我们创建了一个 小数据集,约有 2000 张图像,仅包含“人”和“汽车”的标注;其他所有像素标记为“背景”。
以下是该模型的 label2id
映射:
label2id = {
"person": 0,
"car": 1,
}
使用以下命令进行训练:
python run_instance_segmentation.py \
--model_name_or_path facebook/mask2former-swin-tiny-coco-instance \
--output_dir finetune-instance-segmentation-ade20k-mini-mask2former \
--dataset_name qubvel-hf/ade20k-mini \
--do_reduce_labels \
--image_height 256 \
--image_width 256 \
--do_train \
--fp16 \
--num_train_epochs 40 \
--learning_rate 1e-5 \
--lr_scheduler_type constant \
--per_device_train_batch_size 8 \
--gradient_accumulation_steps 2 \
--dataloader_num_workers 8 \
--dataloader_persistent_workers \
--dataloader_prefetch_factor 4 \
--do_eval \
--evaluation_strategy epoch \
--logging_strategy epoch \
--save_strategy epoch \
--save_total_limit 2 \
--push_to_hub
高级用法
重新加载与推理
你可以轻松加载训练好的模型并进行推理,如下所示:
import torch
import requests
import matplotlib.pyplot as plt
from PIL import Image
from transformers import Mask2FormerForUniversalSegmentation, Mask2FormerImageProcessor
image = Image.open(requests.get("http://farm4.staticflickr.com/3017/3071497290_31f0393363_z.jpg", stream=True).raw)
device = "cuda"
checkpoint = "qubvel-hf/finetune-instance-segmentation-ade20k-mini-mask2former"
model = Mask2FormerForUniversalSegmentation.from_pretrained(checkpoint, device_map=device)
image_processor = Mask2FormerImageProcessor.from_pretrained(checkpoint)
inputs = image_processor(images=[image], return_tensors="pt").to(device)
with torch.no_grad():
outputs = model(**inputs)
outputs = image_processor.post_process_instance_segmentation(outputs, target_sizes=[image.size[::-1]])
print("Mask shape: ", outputs[0]["segmentation"].shape)
print("Mask values: ", outputs[0]["segmentation"].unique())
for segment in outputs[0]["segments_info"]:
print("Segment: ", segment)
运行上述代码后,输出示例如下:
Mask shape: torch.Size([427, 640])
Mask values: tensor([-1., 0., 1., 2., 3., 4., 5., 6.])
Segment: {'id': 0, 'label_id': 0, 'was_fused': False, 'score': 0.946127}
Segment: {'id': 1, 'label_id': 1, 'was_fused': False, 'score': 0.961582}
Segment: {'id': 2, 'label_id': 1, 'was_fused': False, 'score': 0.968367}
Segment: {'id': 3, 'label_id': 1, 'was_fused': False, 'score': 0.819527}
Segment: {'id': 4, 'label_id': 1, 'was_fused': False, 'score': 0.655761}
Segment: {'id': 5, 'label_id': 1, 'was_fused': False, 'score': 0.531299}
Segment: {'id': 6, 'label_id': 1, 'was_fused': False, 'score': 0.929477}
使用以下代码可视化结果:
import numpy as np
import matplotlib.pyplot as plt
segmentation = outputs[0]["segmentation"].numpy()
plt.figure(figsize=(10, 10))
plt.subplot(1, 2, 1)
plt.imshow(np.array(image))
plt.axis("off")
plt.subplot(1, 2, 2)
plt.imshow(segmentation)
plt.axis("off")
plt.show()

📚 详细文档
自定义数据说明
此部分未提供详细说明,暂不展示相关内容。
🔧 技术细节
此部分未提供具体的技术说明(内容少于 50 字),暂不展示技术细节。
📄 许可证
Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.