🚀 犬種多分類圖像識別模型
本項目基於視覺變換器(Vision Transformer)模型,對犬類圖像進行分類,可識別 120 種不同犬種。該模型使用了預訓練的 Google Vision Transformer 模型,並在斯坦福犬類數據集上進行微調,具有較高的準確性和良好的泛化能力。
🚀 快速開始
模型背景
最近,有人問我是否可以將犬類圖像分類為不同的犬種,而不是像我之前的 筆記本 那樣僅僅區分貓和狗。答案是肯定的!
由於問題的複雜性,我們將使用 2020 年 Google 論文 中發佈的最先進的計算機視覺架構——視覺變換器(Vision Transformer)。
模型原理
視覺變換器(Vision Transformer) 與傳統的 卷積神經網絡(CNN) 的區別在於對圖像的處理方式。在 視覺變換器 中,我們將輸入視為原始圖像的一個補丁(例如 16 x 16),並將其作為帶有位置嵌入和自注意力的序列輸入到變換器中;而在 卷積神經網絡(CNN) 中,我們使用相同的原始圖像補丁作為輸入,但使用卷積和池化層作為歸納偏置。這意味著 視覺變換器 可以使用其自注意力機制以“全局”方式關注圖像的任何特定補丁,而無需像 CNN 那樣通過“局部”居中/裁剪/邊界框來引導神經網絡進行卷積操作。
這使得 視覺變換器 架構在本質上更加靈活和可擴展,使我們能夠在計算機視覺中創建 基礎模型,類似於自然語言處理中的基礎模型,如 BERT 和 GPT,通過在大量圖像數據上進行預訓練(自監督/監督),可以推廣到不同的計算機視覺任務,如圖像分類、識別、分割等。這種交叉融合有助於我們更接近通用人工智能的目標。
需要注意的是,與 卷積神經網絡 相比,視覺變換器 的歸納偏置較弱,這使得它具有可擴展性和靈活性。但這一特點(或缺點,取決於你的看法)要求大多數表現良好的預訓練模型需要更多的數據,儘管與 CNN 相比,它的參數更少。
幸運的是,在這個模型中,我們將使用 Google 託管在 HuggingFace 上的 視覺變換器,該模型在 ImageNet-21k 數據集(1400 萬張圖像,21000 個類別)上進行了預訓練,補丁大小為 16x16,分辨率為 224x224,以繞過數據限制。我們將在來自 斯坦福犬類數據集 的約 20000 張圖像的“小”犬種數據集上對該模型進行微調,以將犬類圖像分類為 120 種不同的犬種!
✨ 主要特性
- 基於先進架構:採用視覺變換器(Vision Transformer)架構,具有更好的靈活性和可擴展性。
- 預訓練模型微調:使用在 ImageNet-21k 數據集上預訓練的 Google Vision Transformer 模型,在斯坦福犬類數據集上進行微調,提高模型性能。
- 多指標評估:使用 Top-1 準確率、Top-3 準確率、Top-5 準確率和 Macro F1 等多個指標對模型進行評估,確保模型的準確性和泛化能力。
📦 安裝指南
本模型使用 Python 編寫,依賴於 transformers
、PIL
和 requests
等庫。可以使用以下命令安裝所需的庫:
pip install transformers pillow requests
💻 使用示例
基礎用法
from transformers import AutoImageProcessor, AutoModelForImageClassification
import PIL
import requests
url = "https://upload.wikimedia.org/wikipedia/commons/5/55/Beagle_600.jpg"
image = PIL.Image.open(requests.get(url, stream=True).raw)
image_processor = AutoImageProcessor.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
model = AutoModelForImageClassification.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
📚 詳細文檔
模型描述
本模型使用 Google 視覺變換器(vit-base-patch16-224-in21k) 在 Kaggle 上的斯坦福犬類數據集 上進行微調,以將犬類圖像分類為 120 種不同的犬種。
預期用途和限制
你可以使用這個微調後的模型僅對數據集中包含的犬類圖像和犬種進行分類。
模型訓練指標
輪數 |
Top-1 準確率 |
Top-3 準確率 |
Top-5 準確率 |
Macro F1 |
1 |
79.8% |
95.1% |
97.5% |
77.2% |
2 |
83.8% |
96.7% |
98.2% |
81.9% |
3 |
84.8% |
96.7% |
98.3% |
83.4% |
模型評估指標
Top-1 準確率 |
Top-3 準確率 |
Top-5 準確率 |
Macro F1 |
84.0% |
97.1% |
98.7% |
83.0% |
📄 許可證
本項目採用 MIT 許可證。