🚀 基於Wikiart數據集的DiT模型
本模型是一個基於Wikiart數據集從頭開始訓練的DiT(擴散變壓器)模型,旨在根據藝術流派和風格生成藝術圖像,為藝術創作和圖像生成領域提供了新的可能性。
🚀 快速開始
要使用該模型,你需要安裝 huggingface_hub
庫,並從 “Files and versions” 中下載 modeling_dit_wikiart.py
用於模型定義。之後,你可以使用以下代碼來使用該模型:
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
該模型與 stabilityai/sd-vae-ft-ema
搭配使用。
✨ 主要特性
- 該模型能夠理解藝術流派和風格,並生成具有一定視覺吸引力的繪畫作品。
- 模型有三種不同大小的變體可供選擇,以滿足不同的應用需求。
📦 安裝指南
安裝 huggingface_hub
庫:
pip install huggingface_hub
💻 使用示例
基礎用法
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
📚 詳細文檔
模型描述
此模型是一個在Wikiart數據集(https://huggingface.co/datasets/Artificio/WikiArt )上從頭開始訓練的DiT(擴散變壓器)模型,旨在根據藝術流派和風格生成藝術圖像。
模型架構
該模型在很大程度上借鑑了論文 Scalable Diffusion Models with Transformers 中描述的經典DiT架構,並進行了一些細微的修改:
- 用Wikiart的流派和風格嵌入替換了ImageNet類嵌入;
- 使用後歸一化(post-norm)代替前歸一化(pre-norm);
- 省略了最後的線性層;
- 用學習到的位置嵌入替換了正弦 - 餘弦二維位置嵌入;
- 模型僅預測噪聲,不學習sigma;
- 所有模型變體的
patch_size
都設置為2;
- 模型有不同的大小設置。
如果你感興趣,請查看本倉庫中的
modeling_dit_wikiart.py
以獲取更多詳細信息。
該模型有三種變體:
- S:小型,
num_blocks
= 8,hidden_size
= 384,num_heads
= 6,總參數 = 20M;
- B:基礎型,
num_blocks
= 12,hidden_size
= 640,num_heads
= 10,總參數 = 90M;
- L:大型,
num_blocks
= 16,hidden_size
= 896,num_heads
= 14,總參數 = 234M。
訓練過程
- 數據集:所有模型變體都在103K的Wikiart數據集上進行訓練,並通過水平翻轉進行數據增強。
- 優化器:使用默認設置的AdamW優化器。
- 學習率:在前1%的步驟中進行線性熱身,學習率達到最大值3e - 4,然後在後續步驟中以餘弦衰減至零。
- 訓練輪數和批次大小:
- S:96輪,批次大小為176;
- B:120輪,批次大小為192;
- L:144輪,批次大小為192。
- 設備:
- S:單張RTX 4060ti 16G訓練24小時;
- B:單張RTX 4060ti 16G訓練90小時;
- L:單張RTX 4090D 24G訓練48小時,然後單張RTX 4060ti 16G訓練100小時。
- 損失曲線:所有變體在第一個訓練輪中損失從1.0000以上急劇下降到約0.2000,隨後下降速度變慢,最終在第20輪達到損失 = 0.1600。DiT - S最終達到0.1590;DiT - B最終達到0.1525;DiT - L最終達到0.1510。訓練過程穩定,沒有損失峰值。
性能和侷限性
- 模型展示了理解流派和風格並生成具有視覺吸引力繪畫的基本能力(乍一看)。
- 侷限性包括:
- 無法理解複雜結構,如人臉、建築物等;
- 在要求生成數據集中罕見的流派或風格時,偶爾會出現模式崩潰,例如極簡主義風格和浮世繪流派;
- 分辨率限制為256x256;
- 由於在Wikiart數據集上訓練,因此無法生成超出該數據集範圍的圖像。
🔧 技術細節
該模型基於擴散變壓器架構,在Wikiart數據集上進行訓練。通過對經典DiT架構的修改,使其更適合藝術圖像生成任務。在訓練過程中,使用了數據增強、AdamW優化器和特定的學習率策略,以確保模型的穩定性和性能。
📄 許可證
本模型使用MIT許可證。
屬性 |
詳情 |
模型類型 |
擴散變壓器(DiT) |
訓練數據 |
Wikiart數據集(https://huggingface.co/datasets/Artificio/WikiArt ) |