🚀 视觉艺术作品分类器(vit-artworkclassifier)
本模型能够识别任意输入图像的艺术风格,为艺术图像分类提供了高效的解决方案。它基于预训练模型进行微调,在特定数据集上取得了一定的评估结果,具有一定的实用价值。
🚀 快速开始
本模型是 google/vit-base-patch16-224-in21k 在 imagefolder
数据集上的微调版本。该数据集是 artbench - 10 数据集(https://www.kaggle.com/datasets/alexanderliao/artbench10)的子集,每个类别包含 1000 张图像的训练集和 100 张图像的验证集。
模型在评估集上取得了以下结果:
✨ 主要特性
- 图像风格识别:可准确识别输入图像的艺术风格。
- 微调优化:基于预训练模型微调,在特定数据集上有较好表现。
📚 详细文档
模型描述
你可以在以下链接找到该模型训练项目的相关描述:https://medium.com/@oliverpj.schamp/training-and-evaluating-stable-diffusion-for-artwork-generation-b099d1f5b7a6
预期用途与限制
本模型仅包含 artbench - 10 数据集中的 9 个类别,不包含“ukiyo_e”类别,这是由于数据可用性和格式问题导致的。
训练和评估数据
- 训练数据:从 artbench - 10 中随机选择 1000 张图像(每个类别)。
- 验证数据:从 artbench - 10 中随机选择 100 张图像(每个类别)。
训练过程
训练超参数
训练过程中使用了以下超参数:
- 学习率:0.0001
- 训练批次大小:32
- 评估批次大小:8
- 随机种子:42
- 优化器:Adam(β1 = 0.9,β2 = 0.999,ε = 1e - 08)
- 学习率调度器类型:线性
- 训练轮数:4
- 混合精度训练:Native AMP
训练结果
训练损失 |
轮数 |
步数 |
验证损失 |
准确率 |
1.5906 |
0.36 |
100 |
1.4709 |
0.4847 |
1.3395 |
0.72 |
200 |
1.3208 |
0.5074 |
1.1461 |
1.08 |
300 |
1.3363 |
0.5165 |
0.9593 |
1.44 |
400 |
1.1790 |
0.5846 |
0.8761 |
1.8 |
500 |
1.1252 |
0.5902 |
0.5922 |
2.16 |
600 |
1.1392 |
0.5948 |
0.4803 |
2.52 |
700 |
1.1560 |
0.5936 |
0.4454 |
2.88 |
800 |
1.1545 |
0.6118 |
0.2271 |
3.24 |
900 |
1.2284 |
0.6039 |
0.207 |
3.6 |
1000 |
1.2625 |
0.5959 |
0.1958 |
3.96 |
1100 |
1.2621 |
0.6005 |
框架版本
- Transformers 4.26.1
- Pytorch 1.13.1+cu117
- Datasets 2.9.0
- Tokenizers 0.13.2
💻 使用示例
基础用法
def vit_classify(image):
vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
vit.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit.to(device)
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
encoding = feature_extractor(images=image, return_tensors="pt")
encoding.keys()
pixel_values = encoding['pixel_values'].to(device)
outputs = vit(pixel_values)
logits = outputs.logits
prediction = logits.argmax(-1)
return prediction.item()
📄 许可证
本模型采用 Apache - 2.0 许可证。
属性 |
详情 |
模型类型 |
基于视觉变换器的图像分类模型 |
训练数据 |
artbench - 10 数据集的子集 |
评估指标 |
准确率 |
基础模型 |
google/vit - base - patch16 - 224 - in21k |