🚀 基於ResNet - 50骨幹網絡的條件DETR模型
該模型是基於條件檢測變換器(DETR),在COCO 2017目標檢測數據集(11.8萬張帶標註圖像)上進行端到端訓練得到的。它由Meng等人在論文《用於快速訓練收斂的條件DETR》中提出,並首次在此倉庫發佈。
🚀 快速開始
你可以使用該原始模型進行目標檢測。前往模型中心查看所有可用的條件DETR模型。
from transformers import AutoImageProcessor, ConditionalDetrForObjectDetection
import torch
from PIL import Image
import requests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
processor = AutoImageProcessor.from_pretrained("microsoft/conditional-detr-resnet-50")
model = ConditionalDetrForObjectDetection.from_pretrained("microsoft/conditional-detr-resnet-50")
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.7)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"檢測到 {model.config.id2label[label.item()]},置信度為 "
f"{round(score.item(), 3)},位置為 {box}"
)
上述代碼的輸出結果如下:
檢測到 remote,置信度為 0.833,位置為 [38.31, 72.1, 177.63, 118.45]
檢測到 cat,置信度為 0.831,位置為 [9.2, 51.38, 321.13, 469.0]
檢測到 cat,置信度為 0.804,位置為 [340.3, 16.85, 642.93, 370.95]
目前,特徵提取器和模型均支持PyTorch。
✨ 主要特性
- 快速收斂:實驗結果表明,對於骨幹網絡R50和R101,條件DETR的收斂速度快6.7倍;對於更強的骨幹網絡DC5 - R50和DC5 - R101,收斂速度快10倍。
- 創新機制:提出條件交叉注意力機制,通過學習條件空間查詢,減少對內容嵌入的依賴,降低訓練難度。
📚 詳細文檔
模型描述
最近提出的DETR方法將Transformer的編碼器 - 解碼器架構應用於目標檢測,並取得了不錯的效果。在本文中,作者解決了訓練收斂慢這一關鍵問題,並提出了一種用於DETR快速訓練的條件交叉注意力機制。
DETR中的交叉注意力高度依賴內容嵌入來定位物體的四個端點並預測邊界框,這增加了對高質量內容嵌入的需求,從而加大了訓練難度。條件DETR方法從解碼器嵌入中學習條件空間查詢,用於解碼器多頭交叉注意力。其優點在於,通過條件空間查詢,每個交叉注意力頭能夠關注包含不同區域(如一個物體端點或物體框內的一個區域)的條帶。這縮小了用於物體分類和邊界框迴歸的不同區域的定位空間範圍,從而減輕了對內容嵌入的依賴,簡化了訓練過程。

預期用途與限制
你可以使用該原始模型進行目標檢測。
訓練數據
條件DETR模型在COCO 2017目標檢測數據集上進行訓練,該數據集分別包含11.8萬張和5000張帶標註圖像用於訓練和驗證。
BibTeX引用與引用信息
@inproceedings{MengCFZLYS021,
author = {Depu Meng and
Xiaokang Chen and
Zejia Fan and
Gang Zeng and
Houqiang Li and
Yuhui Yuan and
Lei Sun and
Jingdong Wang},
title = {Conditional {DETR} for Fast Training Convergence},
booktitle = {2021 {IEEE/CVF} International Conference on Computer Vision, {ICCV}
2021, Montreal, QC, Canada, October 10-17, 2021},
}
📄 許可證
本項目採用Apache - 2.0許可證。
📦 信息表格
屬性 |
詳情 |
模型類型 |
基於ResNet - 50骨幹網絡的條件DETR模型 |
訓練數據 |
COCO 2017目標檢測數據集(11.8萬張訓練圖像和5000張驗證圖像) |
標籤 |
目標檢測、視覺 |