🚀 DETR (エンドツーエンド物体検出) モデル(ResNet-101-DC5バックボーン、SKU110Kデータセットで学習、num_queries=400)
このDETR(Detection Transformer)モデルは、SKU110K物体検出データセット(8000枚以上の注釈付き画像)でエンドツーエンドに学習されています。オリジナルモデルとの主な違いは、num_queries
が400であり、SKU110Kデータセットで事前学習されている点です。
🚀 クイックスタート
このモデルの使用方法は以下の通りです。
基本的な使用法
from transformers import DetrImageProcessor, DetrForObjectDetection
import torch
from PIL import Image, ImageOps
import requests
url = "https://github.com/Isalia20/DETR-finetune/blob/main/IMG_3507.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
image = ImageOps.exif_transpose(image)
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101-dc5")
model = DetrForObjectDetection.from_pretrained("isalia99/detr-resnet-101-dc5-sku110k")
model = model.eval()
inputs = processor(images=image, return_tensors="pt")
outputs = model(**inputs)
target_sizes = torch.tensor([image.size[::-1]])
results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.8)[0]
for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
box = [round(i, 2) for i in box.tolist()]
print(
f"Detected {model.config.id2label[label.item()]} with confidence "
f"{round(score.item(), 3)} at location {box}"
)
このコードの出力例は以下の通りです。
Detected LABEL_1 with confidence 0.983 at location [665.49, 480.05, 708.15, 650.11]
Detected LABEL_1 with confidence 0.938 at location [204.99, 1405.9, 239.9, 1546.5]
...
Detected LABEL_1 with confidence 0.998 at location [772.85, 169.49, 829.67, 372.18]
Detected LABEL_1 with confidence 0.999 at location [828.28, 1475.16, 874.37, 1593.43]
現在、特徴抽出器とモデルの両方がPyTorchをサポートしています。
📚 ドキュメント
学習データ
DETRモデルはSKU110Kデータセットで学習されています。このデータセットは、学習/検証/テスト用にそれぞれ8,219/588/2,936枚の注釈付き画像で構成されています。
学習手順
学習
このモデルは、1台のRTX 4060 Ti GPUで学習されました。最初にデコーダのみをファインチューニングし、バッチサイズ1、勾配累積ステップ8で60エポック学習しました。その後、ネットワーク全体をファインチューニングし、同じバッチサイズと勾配累積ステップでさらに60エポック学習しました。
評価結果
このモデルは、SKU110k検証セットで59.8のmAP(平均平均精度)を達成しています。結果はtorchmetrics
のMeanAveragePrecision
クラスを使用して計算されました。
学習コード
学習コードはこのリポジトリ Repo Link で公開されています。ただし、まだ最終版ではなく、十分にテストされていない部分もありますが、主な機能はコードに含まれています。
📄 ライセンス
このプロジェクトはApache-2.0ライセンスの下で公開されています。