🚀 鳥類分類器EfficientNet - B2
本項目的鳥類分類器基於EfficientNet - B2模型,可利用圖像分類技術準確識別鳥類物種,為鳥類識別提供了高效、準確的解決方案。
🚀 快速開始
本模型可用於圖像分類任務。以下是使用該模型對鳥類圖片進行分類的示例代碼。
基礎用法
import torch
import urllib.request
from PIL import Image
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification
url = 'some url'
img = Image.open(urllib.request.urlretrieve(url)[0])
preprocessor = EfficientNetImageProcessor.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
model = EfficientNetForImageClassification.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
inputs = preprocessor(img, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
高級用法
import torch
import urllib.request
from PIL import Image
from transformers import pipeline
url = 'some url'
img = Image.open(urllib.request.urlretrieve(url)[0])
pipe = pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
result = pipe(img)[0]
print(result['label'])
✨ 主要特性
- 高精度識別:在訓練集、驗證集和測試集上均表現出色,訓練集準確率達0.999480,驗證集準確率達0.985904,測試集準確率達0.991238。
- 基於優秀基礎模型:該模型是在[google/efficientnet - b2](https://huggingface.co/google/efficientnet - b2)基礎上微調而來,原模型在ImageNet - 1K上訓練,具備一定的特徵識別能力。
📚 詳細文檔
模型描述
你是否曾看著一隻鳥,心想“要是我知道這是什麼鳥就好了”?除非你是狂熱的觀鳥者(或者只是單純熱愛鳥類),否則很難區分某些鳥類物種。不過你很幸運,事實證明可以使用圖像分類器來識別鳥類物種!
本模型是[google/efficientnet - b2](https://huggingface.co/google/efficientnet - b2)在Kaggle上的[gpiosenka/100 - bird - species](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)數據集上的微調版本。用於訓練模型的數據集於2023年9月24日獲取。
原始模型本身在ImageNet - 1K上進行訓練,因此它可能仍然具有一些用於識別鳥類等生物的有用特徵。
理論上,在該數據集上隨機猜測的準確率為0.0019047619(本質上是1/525)。該模型在所有三個數據集上的表現都非常好,結果如下:
- 訓練集:0.999480
- 驗證集:0.985904
- 測試集:0.991238
預期用途
你可以使用原始模型進行圖像分類。上述代碼示例展示了模型的具體使用方式。
訓練與評估
數據
數據集來自Kaggle上的[gpiosenka/100 - bird - species](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)。它包含525種鳥類,有84635張訓練圖像,驗證集和測試集各有2625張圖像。數據集中的每張圖像都是224×224的RGB圖像。
訓練過程使用了作者提供的相同數據劃分。更多詳細信息,請參考[作者的Kaggle頁面](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)。
訓練過程
訓練使用PyTorch在Kaggle的免費P100 GPU上進行。該過程還使用了Lightning和Torchmetrics庫。
預處理
每張圖像根據原作者的[配置](https://huggingface.co/google/efficientnet - b2/blob/main/preprocessor_config.json)進行預處理。
訓練集還使用了以下數據增強方法:
- 以50%的概率隨機旋轉10度
- 以50%的概率隨機水平翻轉
訓練超參數
以下是訓練使用的超參數:
屬性 |
詳情 |
訓練模式 |
fp32 |
損失函數 |
交叉熵 |
優化器 |
Adam(默認betas為(0.99, 0.999)) |
學習率 |
1e - 3 |
學習率調度器 |
Reduce on plateau(監控驗證損失,耐心值為2,衰減率為0.1) |
批量大小 |
64 |
提前停止策略 |
監控驗證準確率,耐心值為10 |
結果
下圖是訓練過程在訓練集和驗證集上的結果:

📄 許可證
本項目採用Apache - 2.0許可證。