🚀 ResNet50 v1.5 圖像分類模型
ResNet50 v1.5 是一款用於圖像分類的模型,它在原始 ResNet50 v1 模型基礎上進行了改進,提升了一定的準確性,可利用 NVIDIA GPU 架構的 Tensor Cores 進行混合精度訓練,還能部署在 NVIDIA Triton 推理服務器上進行推理。
🚀 快速開始
本模型可用於圖像分類任務,下面將介紹如何使用預訓練的 ResNet50 v1.5 模型對圖像進行推理並展示結果。
✨ 主要特性
- 改進版本:ResNet50 v1.5 是 原始 ResNet50 v1 模型 的改進版本,在瓶頸塊的下采樣操作上與 v1 有所不同,使得其準確率比 v1 略高(約 0.5% top1),但性能略有下降(約 5% imgs/sec)。
- 混合精度訓練:該模型使用 Volta、Turing 和 NVIDIA Ampere GPU 架構上的 Tensor Cores 進行混合精度訓練,研究人員可以比不使用 Tensor Cores 時快 2 倍以上得到結果,同時體驗混合精度訓練的好處。
- 一致性測試:該模型針對每個 NGC 月度容器版本進行測試,以確保隨著時間的推移保持一致的準確性和性能。
- 可部署性:ResNet50 v1.5 模型可以使用 TorchScript、ONNX Runtime 或 TensorRT 作為執行後端,部署在 NVIDIA Triton 推理服務器 上進行推理。
📦 安裝指南
運行示例需要安裝一些額外的 Python 包,用於圖像預處理和可視化:
!pip install validators matplotlib
💻 使用示例
基礎用法
以下是使用預訓練的 ResNet50 v1.5 模型對圖像進行推理的示例代碼:
import torch
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import json
import requests
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
%matplotlib inline
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f'Using {device} for inference')
加載在 IMAGENET 數據集上預訓練的模型:
resnet50 = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_resnet50', pretrained=True)
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_convnets_processing_utils')
resnet50.eval().to(device)
準備樣本輸入數據:
uris = [
'http://images.cocodataset.org/test-stuff2017/000000024309.jpg',
'http://images.cocodataset.org/test-stuff2017/000000028117.jpg',
'http://images.cocodataset.org/test-stuff2017/000000006149.jpg',
'http://images.cocodataset.org/test-stuff2017/000000004954.jpg',
]
batch = torch.cat(
[utils.prepare_input_from_uri(uri) for uri in uris]
).to(device)
運行推理,使用 pick_n_best(predictions=output, n=topN)
輔助函數根據模型選擇 N 個最可能的假設:
with torch.no_grad():
output = torch.nn.functional.softmax(resnet50(batch), dim=1)
results = utils.pick_n_best(predictions=output, n=5)
顯示結果:
for uri, result in zip(uris, results):
img = Image.open(requests.get(uri, stream=True).raw)
img.thumbnail((256,256), Image.ANTIALIAS)
plt.imshow(img)
plt.show()
print(result)
📚 詳細文檔
有關模型輸入和輸出、訓練方法、推理和性能的詳細信息,請訪問:
github
和/或 NGC
📄 許可證
本項目採用 Apache-2.0 許可證。
🔗 參考資料