🚀 DETR(端到端目标检测)模型:基于ResNet - 101 - DC5骨干网络,在SKU110K数据集上训练,num_queries为400
本项目的DETR(Detection Transformer)模型在SKU110K目标检测数据集(包含8000张带注释图像)上进行了端到端训练。与原始模型的主要区别在于,本模型的num_queries
设置为400,并且在SKU110K数据集上进行了预训练。
🚀 快速开始
模型使用方法
以下是使用该模型的示例代码:
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageOps
import requests
url = "https://github.com/Isalia20/DETR-finetune/blob/main/IMG_3507.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
image = ImageOps.exif_transpose(image)
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101-dc5")
model = DetrForObjectDetection.from_pretrained("isalia99/detr-resnet-101-dc5-sku110k")
model = model.eval()
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.8)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
运行上述代码,输出示例如下:
Detected LABEL_1 with confidence 0.983 at location [665.49, 480.05, 708.15, 650.11]
Detected LABEL_1 with confidence 0.938 at location [204.99, 1405.9, 239.9, 1546.5]
...
Detected LABEL_1 with confidence 0.998 at location [772.85, 169.49, 829.67, 372.18]
Detected LABEL_1 with confidence 0.999 at location [828.28, 1475.16, 874.37, 1593.43]
目前,特征提取器和模型均支持PyTorch。
📚 详细文档
训练数据
DETR模型在SKU110K数据集上进行训练,该数据集分别包含8219/588/2936张带注释的图像用于训练、验证和测试。
训练过程
- 训练设置:模型在1块RTX 4060 Ti GPU上进行训练。首先仅对解码器进行微调,训练60个epoch,批次大小为1,梯度累积步数设置为8;然后对整个网络进行微调,同样训练60个epoch,批次大小为1,梯度累积步数为8。
评估结果
该模型在SKU110K验证集上的平均精度均值(mAP)达到了59.8。评估结果使用torchmetrics
库中的MeanAveragePrecision
类进行计算。
训练代码
训练代码已发布在本仓库中,点击查看。不过,代码尚未最终确定和充分测试,但主要功能已包含在内。
📄 许可证
本项目采用Apache - 2.0许可证。