🚀 クイックスタート
このモデルは、Wikiartデータセットhttps://huggingface.co/datasets/Artificio/WikiArt を使用してゼロから学習されたDiT(拡散トランスフォーマー)です。美術ジャンルと美術スタイルを指定することで、美術画像を生成するように設計されています。
✨ 主な機能
- 美術ジャンルとスタイルを指定して美術画像を生成できます。
- 3つのモデルバリアント(S, B, L)があり、それぞれ異なるサイズ設定です。
📦 インストール
モデルを使用するには、"huggingface_hub" ライブラリをインストールし、"Files and versions" からモデル定義用のmodeling_dit_wikiart.pyをダウンロードします。
💻 使用例
基本的な使用法
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
高度な使用法
モデルはstabilityai/sd-vae-ft-emaとペアになっています。これを利用してさらに高度な画像生成を行うことができます。
📚 ドキュメント
モデルのアーキテクチャ
このモデルは、論文 Scalable Diffusion Models with Transformers で説明されている古典的なDiTアーキテクチャを大まかに模倣しており、いくつかの小さな変更が加えられています。
- ImageNetクラスの埋め込みをWikiartのジャンルとスタイルの埋め込みに置き換えました。
- 事前正規化ではなく事後正規化を使用しました。
- 最終の線形層を省略しました。
- sin-cos-2d位置埋め込みを学習済みの位置埋め込みに置き換えました。
- モデルはノイズのみを予測し、sigmaを学習しません。
- すべてのモデルバリアントでpatch_size=2に設定しました。
- モデルには異なるサイズ設定があります。
モデルには3つのバリアントがあります。
- S: 小型、num_blocks=8、hidden_size=384、num_heads=6、総パラメータ数=20M
- B: ベース、num_blocks=12、hidden_size=640、num_heads=10、総パラメータ数=90M
- L: 大型、num_blocks=16、hidden_size=896、num_heads=14、総パラメータ数=234M
詳細については、このリポジトリ内のmodeling_dit_wikiart.pyを確認してください。
学習手順
- データセット: すべてのモデルバリアントは、水平反転によるデータ拡張を行った103KのWikiartデータセットで学習されました。
- オプティマイザ: デフォルト設定のAdamW。
- 学習率: 最初の1%のステップで線形ウォームアップを行い、学習率は最大3e-4に達し、その後のステップでコサイン減衰してゼロになります。
- エポック数とバッチサイズ:
- S: バッチサイズ176で96エポック
- B: バッチサイズ192で120エポック
- L: バッチサイズ192で144エポック
- デバイス:
- S: 単一のRTX 4060ti 16Gで24時間
- B: 単一のRTX 4060ti 16Gで90時間
- L: 単一のRTX 4090D 24Gで48時間、その後単一のRTX 4060ti 16Gで100時間
- 損失曲線: すべてのバリアントは、最初のエポックで損失が1.0000以上から0.2000付近まで大幅に減少し、その後ははるかにゆっくりと減少し、20エポックで最終的に損失=0.1600に達しました。DiT-Sは最終的に0.1590に達し、DiT-Bは最終的に0.1525に達し、DiT-Lは最終的に0.1510に達しました。学習は安定しており、損失の急上昇はありません。
性能と制限
- モデルは、ジャンルとスタイルを理解し、視覚的に魅力的な絵画を生成する基本的な能力を示しています(一見すると)。
- 制限事項には以下のものがあります。
- 人の顔や建物などの複雑な構造を理解できません。
- データセットでめったに見られないジャンルやスタイルを生成するように要求された場合、時々モーダル崩壊が発生します。たとえば、ミニマリズムのスタイルや浮世絵のジャンルなどです。
- 解像度は256x256に制限されています。
- Wikiartデータセットで学習されているため、範囲外の画像を生成することはできません。
🔧 技術詳細
このモデルは、Wikiartデータセットを使用してゼロから学習されたDiT(拡散トランスフォーマー)です。古典的なDiTアーキテクチャをベースに、いくつかの小さな変更が加えられています。学習手順では、データ拡張やオプティマイザ、学習率の設定などが行われており、損失曲線は安定しています。
📄 ライセンス
このモデルはMITライセンスの下で提供されています。
プロパティ |
詳細 |
モデルタイプ |
DiT(拡散トランスフォーマー) |
学習データ |
103KのWikiartデータセット(水平反転によるデータ拡張あり) |