🚀 DiT-Wikiart-Small 模型
本模型是一個基於擴散變壓器(Diffusion Transformer)架構的模型,專門用於無條件圖像生成。它在Wikiart數據集上從頭開始訓練,能夠根據藝術流派和風格生成藝術圖像。
🚀 快速開始
要使用此模型,你需要安裝 huggingface_hub
庫,並從“文件和版本”中下載 modeling_dit_wikiart.py
用於模型定義。之後,你可以使用以下代碼來使用該模型:
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Small")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
該模型與 stabilityai/sd-vae-ft-ema
搭配使用。
✨ 主要特性
- 基於擴散變壓器(DiT)架構,在Wikiart數據集上從頭開始訓練。
- 能夠根據給定的藝術流派和風格生成藝術圖像。
- 模型有三種不同大小的變體,以滿足不同的需求。
📦 安裝指南
要使用此模型,你需要安裝 huggingface_hub
庫,並從“文件和版本”中下載 modeling_dit_wikiart.py
用於模型定義。
- 庫鏈接:https://hf-mirror.com/kaupane/DiT-Wikiart-Small
💻 使用示例
基礎用法
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Small")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
📚 詳細文檔
模型描述
本模型是一個在Wikiart數據集(https://huggingface.co/datasets/Artificio/WikiArt )上從頭開始訓練的DiT(擴散變壓器)模型。它旨在根據藝術流派和風格生成藝術圖像。
模型架構
該模型在很大程度上借鑑了論文 Scalable Diffusion Models with Transformers 中描述的經典DiT架構,並進行了一些細微的修改:
- 用Wikiart的流派和風格嵌入替換了ImageNet的類別嵌入;
- 使用後歸一化(post-norm)代替前歸一化(pre-norm);
- 省略了最後的線性層;
- 用學習到的位置嵌入替換了正弦 - 餘弦二維位置嵌入;
- 模型僅預測噪聲,不學習sigma;
- 所有模型變體的
patch_size
都設置為2;
- 模型有不同的大小設置。
如果你感興趣,可以查看此倉庫中的 modeling_dit_wikiart.py
以獲取更多詳細信息。
模型有三種變體:
- S:小型,
num_blocks=8
,hidden_size=384
,num_heads=6
,總參數為20M;
- B:基礎型,
num_blocks=12
,hidden_size=640
,num_heads=10
,總參數為90M;
- L:大型,
num_blocks=16
,hidden_size=896
,num_heads=14
,總參數為234M。
訓練過程
- 數據集:所有模型變體都在103K的Wikiart數據集上進行訓練,並通過水平翻轉進行數據增強。
- 優化器:使用默認設置的AdamW優化器。
- 學習率:在前1%的步驟中進行線性熱身,學習率達到最大值3e-4,然後在後續步驟中進行餘弦衰減至零。
- 訓練輪數和批次大小:
- S:96輪,批次大小為176;
- B:120輪,批次大小為192;
- L:144輪,批次大小為192。
- 設備:
- S:單張RTX 4060ti 16G,訓練24小時;
- B:單張RTX 4060ti 16G,訓練90小時;
- L:先使用單張RTX 4090D 24G訓練48小時,再使用單張RTX 4060ti 16G訓練100小時。
- 損失曲線:所有變體在第一個訓練輪次中損失從1.0000以上急劇下降到0.2000左右,隨後下降速度變慢,在第20個訓練輪次時最終達到損失值0.1600。DiT-S最終達到0.1590;DiT-B最終達到0.1525;DiT-L最終達到0.1510。
性能和侷限性
- 模型展示了理解流派和風格並生成視覺上吸引人的繪畫的基本能力(乍一看)。
- 侷限性包括:
- 無法理解複雜的結構,如人臉、建築物等。
- 在要求生成數據集中罕見的流派或風格時,偶爾會出現模式崩潰。例如極簡主義風格和浮世繪流派。
- 分辨率限制為256x256。
- 由於在Wikiart數據集上訓練,因此無法生成超出該數據集範圍的圖像。
📄 許可證
本模型採用MIT許可證。
📋 模型信息