🚀 蝴蝶生成对抗网络(Butterfly GAN)
蝴蝶生成对抗网络(Butterfly GAN)是一个用于生成蝴蝶图像的模型。它基于特定的论文进行开发,能够在较少的训练样本和较短的训练时间内实现较好的图像生成效果,可用于娱乐和学习。
🚀 快速开始
模型使用示例
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512")
gan.eval()
batch_size = 1
with torch.no_grad():
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
✨ 主要特性
- 训练效率高:基于论文 Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis,该模型在单个 RTX - 2080 GPU 上仅需几小时的训练就能从初始状态收敛。
- 小样本表现好:即使训练样本少于 100 个,模型也能保持稳定的性能。
- 代码可复现:使用从 lucidrains 仓库 改编的脚本进行训练,同时采用了官方仓库的变换操作。
📦 安装指南
文档未提及具体安装步骤,可参考 社区活动仓库 进行安装。
💻 使用示例
基础用法
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512")
gan.eval()
batch_size = 1
with torch.no_grad():
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
📚 详细文档
模型描述
该模型基于 论文 Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis,也被称为 Light - GAN 模型。此模型使用 这里 的脚本进行训练,该脚本改编自 lucidrains 仓库。与上述脚本不同的是,使用了官方仓库的变换操作,因为训练图像已经经过裁剪和对齐。官方论文实现 仓库。
transform_list = [
transforms.Resize((int(im_size),int(im_size))),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]
预期用途与限制
- 预期用途:用于娱乐和学习。
- 限制和偏差:
- 训练时,对数据集进行过滤,每个物种仅保留 1 只蝴蝶,否则模型生成的蝴蝶多样性会降低。
- 数据集还使用 CLIP 分数对 ['pretty butterfly', 'one butterfly', 'butterfly with open wings', 'colorful butterfly'] 进行过滤,这种方法可能会导致某些蝴蝶因被判定为不符合标准而被排除在数据集之外,从而产生偏差。
训练数据
使用了 1000 张图像进行训练。虽然可以增加图像数量,但没有时间手动整理数据集,同时也想验证论文中提到的低数据训练的可行性。更多细节可查看 数据卡片。
训练过程
在 2 个 A4000 GPU 上训练约 1 天,7 - 12 小时内即可看到较好的效果。重要参数:"--batch_size 64 --gradient_accumulate_every 4 --image_size 512 --mixed_precision fp16"。训练日志可查看 这里。
评估结果
在 100 张图像上计算了 FID 分数,不同检查点的结果可查看 这里,但由于 FID 分数的局限性,其意义可能不大。
生成图像
可以在 演示 中体验模型生成的蝴蝶图像。
BibTeX 引用和引用信息
该模型在 HugGAN 冲刺活动中开发。
模型训练者:Ceyda Cinarel https://twitter.com/ceyda_cinarel
额外贡献者:Jonathan Whitaker https://twitter.com/johnowhitaker
📄 许可证
本项目采用 MIT 许可证。
属性 |
详情 |
模型类型 |
蝴蝶生成对抗网络(Butterfly GAN) |
训练数据 |
1000 张蝴蝶图像,来自 huggan/smithsonian_butterflies_subset 数据集 |