🚀 基于Wikiart数据集的DiT模型
本模型是一个基于Wikiart数据集从头开始训练的DiT(扩散变压器)模型,旨在根据艺术流派和风格生成艺术图像,为艺术创作和图像生成领域提供了新的可能性。
🚀 快速开始
要使用该模型,你需要安装 huggingface_hub
库,并从 “Files and versions” 中下载 modeling_dit_wikiart.py
用于模型定义。之后,你可以使用以下代码来使用该模型:
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
该模型与 stabilityai/sd-vae-ft-ema
搭配使用。
✨ 主要特性
- 该模型能够理解艺术流派和风格,并生成具有一定视觉吸引力的绘画作品。
- 模型有三种不同大小的变体可供选择,以满足不同的应用需求。
📦 安装指南
安装 huggingface_hub
库:
pip install huggingface_hub
💻 使用示例
基础用法
from modeling_dit_wikiart import DiTWikiartModel
model = DiTWikiartModel.from_pretrained("kaupane/DiT-Wikiart-Large")
num_samples = 8
noisy_latents = torch.randn(num_samples,4,32,32)
predicted_noise = model(noisy_latents)
print(predicted_noise)
📚 详细文档
模型描述
此模型是一个在Wikiart数据集(https://huggingface.co/datasets/Artificio/WikiArt )上从头开始训练的DiT(扩散变压器)模型,旨在根据艺术流派和风格生成艺术图像。
模型架构
该模型在很大程度上借鉴了论文 Scalable Diffusion Models with Transformers 中描述的经典DiT架构,并进行了一些细微的修改:
- 用Wikiart的流派和风格嵌入替换了ImageNet类嵌入;
- 使用后归一化(post-norm)代替前归一化(pre-norm);
- 省略了最后的线性层;
- 用学习到的位置嵌入替换了正弦 - 余弦二维位置嵌入;
- 模型仅预测噪声,不学习sigma;
- 所有模型变体的
patch_size
都设置为2;
- 模型有不同的大小设置。
如果你感兴趣,请查看本仓库中的
modeling_dit_wikiart.py
以获取更多详细信息。
该模型有三种变体:
- S:小型,
num_blocks
= 8,hidden_size
= 384,num_heads
= 6,总参数 = 20M;
- B:基础型,
num_blocks
= 12,hidden_size
= 640,num_heads
= 10,总参数 = 90M;
- L:大型,
num_blocks
= 16,hidden_size
= 896,num_heads
= 14,总参数 = 234M。
训练过程
- 数据集:所有模型变体都在103K的Wikiart数据集上进行训练,并通过水平翻转进行数据增强。
- 优化器:使用默认设置的AdamW优化器。
- 学习率:在前1%的步骤中进行线性热身,学习率达到最大值3e - 4,然后在后续步骤中以余弦衰减至零。
- 训练轮数和批次大小:
- S:96轮,批次大小为176;
- B:120轮,批次大小为192;
- L:144轮,批次大小为192。
- 设备:
- S:单张RTX 4060ti 16G训练24小时;
- B:单张RTX 4060ti 16G训练90小时;
- L:单张RTX 4090D 24G训练48小时,然后单张RTX 4060ti 16G训练100小时。
- 损失曲线:所有变体在第一个训练轮中损失从1.0000以上急剧下降到约0.2000,随后下降速度变慢,最终在第20轮达到损失 = 0.1600。DiT - S最终达到0.1590;DiT - B最终达到0.1525;DiT - L最终达到0.1510。训练过程稳定,没有损失峰值。
性能和局限性
- 模型展示了理解流派和风格并生成具有视觉吸引力绘画的基本能力(乍一看)。
- 局限性包括:
- 无法理解复杂结构,如人脸、建筑物等;
- 在要求生成数据集中罕见的流派或风格时,偶尔会出现模式崩溃,例如极简主义风格和浮世绘流派;
- 分辨率限制为256x256;
- 由于在Wikiart数据集上训练,因此无法生成超出该数据集范围的图像。
🔧 技术细节
该模型基于扩散变压器架构,在Wikiart数据集上进行训练。通过对经典DiT架构的修改,使其更适合艺术图像生成任务。在训练过程中,使用了数据增强、AdamW优化器和特定的学习率策略,以确保模型的稳定性和性能。
📄 许可证
本模型使用MIT许可证。
属性 |
详情 |
模型类型 |
扩散变压器(DiT) |
训练数据 |
Wikiart数据集(https://huggingface.co/datasets/Artificio/WikiArt ) |