🚀 vit-artworkclassifier
このモデルは、入力された任意の画像のアートワークスタイルを返します。
このモデルは、google/vit-base-patch16-224-in21k を imagefolder データセットでファインチューニングしたバージョンです。これは artbench - 10 データセット (https://www.kaggle.com/datasets/alexanderliao/artbench10) のサブセットで、各クラスに 1000 件のアートワークを含むトレーニングセットと、各クラスに 100 件のアートワークを含む検証セットがあります。
評価セットでは、以下の結果を達成しています。
📚 ドキュメント
モデルの説明
このモデルがトレーニングされたプロジェクトの説明はこちらにあります: https://medium.com/@oliverpj.schamp/training-and-evaluating-stable-diffusion-for-artwork-generation-b099d1f5b7a6
想定される用途と制限
このモデルは、artbench - 10 の 10 クラスのうち 9 クラスのみを含んでおり、浮世絵 (ukiyo_e) は含まれていません。これは、データの入手可能性とフォーマットの問題によるものです。
トレーニングと評価データ
トレーニング: artbench - 10 からランダムに選択された 1000 枚の画像 (各クラス)。検証: artbench - 10 からランダムに選択された 100 枚の画像 (各クラス)。
トレーニング手順
トレーニングハイパーパラメータ
トレーニング中に使用されたハイパーパラメータは以下の通りです。
- 学習率: 0.0001
- トレーニングバッチサイズ: 32
- 評価バッチサイズ: 8
- シード: 42
- オプティマイザ: Adam (ベータ=(0.9, 0.999)、イプシロン=1e - 08)
- 学習率スケジューラのタイプ: 線形
- エポック数: 4
- 混合精度トレーニング: Native AMP
トレーニング結果
トレーニング損失 |
エポック |
ステップ |
検証損失 |
正解率 |
1.5906 |
0.36 |
100 |
1.4709 |
0.4847 |
1.3395 |
0.72 |
200 |
1.3208 |
0.5074 |
1.1461 |
1.08 |
300 |
1.3363 |
0.5165 |
0.9593 |
1.44 |
400 |
1.1790 |
0.5846 |
0.8761 |
1.8 |
500 |
1.1252 |
0.5902 |
0.5922 |
2.16 |
600 |
1.1392 |
0.5948 |
0.4803 |
2.52 |
700 |
1.1560 |
0.5936 |
0.4454 |
2.88 |
800 |
1.1545 |
0.6118 |
0.2271 |
3.24 |
900 |
1.2284 |
0.6039 |
0.207 |
3.6 |
1000 |
1.2625 |
0.5959 |
0.1958 |
3.96 |
1100 |
1.2621 |
0.6005 |
フレームワークのバージョン
- Transformers 4.26.1
- Pytorch 1.13.1+cu117
- Datasets 2.9.0
- Tokenizers 0.13.2
実行コード
def vit_classify(image):
vit = ViTForImageClassification.from_pretrained("oschamp/vit-artworkclassifier")
vit.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vit.to(device)
model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
encoding = feature_extractor(images=image, return_tensors="pt")
encoding.keys()
pixel_values = encoding['pixel_values'].to(device)
outputs = vit(pixel_values)
logits = outputs.logits
prediction = logits.argmax(-1)
return prediction.item()
📄 ライセンス
このモデルは Apache - 2.0 ライセンスの下で提供されています。
📦 モデル情報
属性 |
詳情 |
モデルタイプ |
画像分類モデル |
トレーニングデータ |
artbench - 10 データセットのサブセット |