🚀 鸟类分类器EfficientNet - B2
本项目的鸟类分类器基于EfficientNet - B2模型,可利用图像分类技术准确识别鸟类物种,为鸟类识别提供了高效、准确的解决方案。
🚀 快速开始
本模型可用于图像分类任务。以下是使用该模型对鸟类图片进行分类的示例代码。
基础用法
import torch
import urllib.request
from PIL import Image
from transformers import EfficientNetImageProcessor, EfficientNetForImageClassification
url = 'some url'
img = Image.open(urllib.request.urlretrieve(url)[0])
preprocessor = EfficientNetImageProcessor.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
model = EfficientNetForImageClassification.from_pretrained("dennisjooo/Birds-Classifier-EfficientNetB2")
inputs = preprocessor(img, return_tensors="pt")
with torch.no_grad():
logits = model(**inputs).logits
predicted_label = logits.argmax(-1).item()
print(model.config.id2label[predicted_label])
高级用法
import torch
import urllib.request
from PIL import Image
from transformers import pipeline
url = 'some url'
img = Image.open(urllib.request.urlretrieve(url)[0])
pipe = pipeline("image-classification", model="dennisjooo/Birds-Classifier-EfficientNetB2")
result = pipe(img)[0]
print(result['label'])
✨ 主要特性
- 高精度识别:在训练集、验证集和测试集上均表现出色,训练集准确率达0.999480,验证集准确率达0.985904,测试集准确率达0.991238。
- 基于优秀基础模型:该模型是在[google/efficientnet - b2](https://huggingface.co/google/efficientnet - b2)基础上微调而来,原模型在ImageNet - 1K上训练,具备一定的特征识别能力。
📚 详细文档
模型描述
你是否曾看着一只鸟,心想“要是我知道这是什么鸟就好了”?除非你是狂热的观鸟者(或者只是单纯热爱鸟类),否则很难区分某些鸟类物种。不过你很幸运,事实证明可以使用图像分类器来识别鸟类物种!
本模型是[google/efficientnet - b2](https://huggingface.co/google/efficientnet - b2)在Kaggle上的[gpiosenka/100 - bird - species](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)数据集上的微调版本。用于训练模型的数据集于2023年9月24日获取。
原始模型本身在ImageNet - 1K上进行训练,因此它可能仍然具有一些用于识别鸟类等生物的有用特征。
理论上,在该数据集上随机猜测的准确率为0.0019047619(本质上是1/525)。该模型在所有三个数据集上的表现都非常好,结果如下:
- 训练集:0.999480
- 验证集:0.985904
- 测试集:0.991238
预期用途
你可以使用原始模型进行图像分类。上述代码示例展示了模型的具体使用方式。
训练与评估
数据
数据集来自Kaggle上的[gpiosenka/100 - bird - species](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)。它包含525种鸟类,有84635张训练图像,验证集和测试集各有2625张图像。数据集中的每张图像都是224×224的RGB图像。
训练过程使用了作者提供的相同数据划分。更多详细信息,请参考[作者的Kaggle页面](https://www.kaggle.com/datasets/gpiosenka/100 - bird - species)。
训练过程
训练使用PyTorch在Kaggle的免费P100 GPU上进行。该过程还使用了Lightning和Torchmetrics库。
预处理
每张图像根据原作者的[配置](https://huggingface.co/google/efficientnet - b2/blob/main/preprocessor_config.json)进行预处理。
训练集还使用了以下数据增强方法:
- 以50%的概率随机旋转10度
- 以50%的概率随机水平翻转
训练超参数
以下是训练使用的超参数:
属性 |
详情 |
训练模式 |
fp32 |
损失函数 |
交叉熵 |
优化器 |
Adam(默认betas为(0.99, 0.999)) |
学习率 |
1e - 3 |
学习率调度器 |
Reduce on plateau(监控验证损失,耐心值为2,衰减率为0.1) |
批量大小 |
64 |
提前停止策略 |
监控验证准确率,耐心值为10 |
结果
下图是训练过程在训练集和验证集上的结果:

📄 许可证
本项目采用Apache - 2.0许可证。