🚀 基于ResNet - 50骨干网络的条件DETR模型
该模型是基于条件检测变换器(DETR),在COCO 2017目标检测数据集(11.8万张带标注图像)上进行端到端训练得到的。它由Meng等人在论文《用于快速训练收敛的条件DETR》中提出,并首次在此仓库发布。
🚀 快速开始
你可以使用该原始模型进行目标检测。前往模型中心查看所有可用的条件DETR模型。
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"检测到 {model.config.id2label[label.item()]},置信度为 "
f"{round(score.item(), 3)},位置为 {box}"
)
上述代码的输出结果如下:
检测到 remote,置信度为 0.833,位置为 [38.31, 72.1, 177.63, 118.45]
检测到 cat,置信度为 0.831,位置为 [9.2, 51.38, 321.13, 469.0]
检测到 cat,置信度为 0.804,位置为 [340.3, 16.85, 642.93, 370.95]
目前,特征提取器和模型均支持PyTorch。
✨ 主要特性
- 快速收敛:实验结果表明,对于骨干网络R50和R101,条件DETR的收敛速度快6.7倍;对于更强的骨干网络DC5 - R50和DC5 - R101,收敛速度快10倍。
- 创新机制:提出条件交叉注意力机制,通过学习条件空间查询,减少对内容嵌入的依赖,降低训练难度。
📚 详细文档
模型描述
最近提出的DETR方法将Transformer的编码器 - 解码器架构应用于目标检测,并取得了不错的效果。在本文中,作者解决了训练收敛慢这一关键问题,并提出了一种用于DETR快速训练的条件交叉注意力机制。
DETR中的交叉注意力高度依赖内容嵌入来定位物体的四个端点并预测边界框,这增加了对高质量内容嵌入的需求,从而加大了训练难度。条件DETR方法从解码器嵌入中学习条件空间查询,用于解码器多头交叉注意力。其优点在于,通过条件空间查询,每个交叉注意力头能够关注包含不同区域(如一个物体端点或物体框内的一个区域)的条带。这缩小了用于物体分类和边界框回归的不同区域的定位空间范围,从而减轻了对内容嵌入的依赖,简化了训练过程。

预期用途与限制
你可以使用该原始模型进行目标检测。
训练数据
条件DETR模型在COCO 2017目标检测数据集上进行训练,该数据集分别包含11.8万张和5000张带标注图像用于训练和验证。
BibTeX引用与引用信息
@inproceedings{MengCFZLYS021,
author = {Depu Meng and
Xiaokang Chen and
Zejia Fan and
Gang Zeng and
Houqiang Li and
Yuhui Yuan and
Lei Sun and
Jingdong Wang},
title = {Conditional {DETR} for Fast Training Convergence},
booktitle = {2021 {IEEE/CVF} International Conference on Computer Vision, {ICCV}
2021, Montreal, QC, Canada, October 10-17, 2021},
}
📄 许可证
本项目采用Apache - 2.0许可证。
📦 信息表格
属性 |
详情 |
模型类型 |
基于ResNet - 50骨干网络的条件DETR模型 |
训练数据 |
COCO 2017目标检测数据集(11.8万张训练图像和5000张验证图像) |
标签 |
目标检测、视觉 |