🚀 DETR(端到端目標檢測)模型
本項目是一個基於ResNet - 50骨幹網絡的DETR(端到端目標檢測)模型,在SKU110K數據集上進行訓練,設置了400個查詢數(num_queries)。該模型能夠有效解決目標檢測問題,為相關領域的應用提供了強大的支持。
🚀 快速開始
DETR(Detection Transformer)模型在SKU110K目標檢測數據集(包含8000張標註圖像)上進行了端到端的訓練。與原始模型相比,主要區別在於設置了400個查詢數(num_queries),並且在SKU110K數據集上進行了預訓練。
模型使用方法
以下是使用該模型的示例代碼:
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageOps
import requests
url = "https://github.com/Isalia20/DETR-finetune/blob/main/IMG_3507.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
image = ImageOps.exif_transpose(image)
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrForObjectDetection.from_pretrained("isalia99/detr-resnet-50-sku110k")
model = model.eval()
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.8)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
代碼運行後,預期輸出如下:
Detected LABEL_1 with confidence 0.983 at location [665.49, 480.05, 708.15, 650.11]
Detected LABEL_1 with confidence 0.938 at location [204.99, 1405.9, 239.9, 1546.5]
...
Detected LABEL_1 with confidence 0.998 at location [772.85, 169.49, 829.67, 372.18]
Detected LABEL_1 with confidence 0.999 at location [828.28, 1475.16, 874.37, 1593.43]
目前,特徵提取器和模型均支持PyTorch。
📚 詳細文檔
訓練數據
DETR模型在SKU110K數據集上進行訓練,該數據集分別包含8219/588/2936張標註圖像用於訓練/驗證/測試。
訓練過程
訓練
模型在1塊RTX 4060 Ti GPU上進行訓練,前140個epoch僅微調解碼器,批量大小為8;後70個epoch微調整個網絡,批量大小為3,並進行3步梯度累積。
評估結果
該模型在SKU110k驗證集上實現了58.9的平均精度均值(mAP)。結果使用torchmetrics的MeanAveragePrecision類進行計算。
訓練代碼
訓練代碼已發佈在本倉庫 倉庫鏈接。不過,代碼尚未最終確定和充分測試,但主要功能已包含在內。
📄 許可證
本項目採用Apache - 2.0許可證。