🚀 蝴蝶生成對抗網絡(Butterfly GAN)
蝴蝶生成對抗網絡(Butterfly GAN)是一個用於生成蝴蝶圖像的模型。它基於特定的論文進行開發,能夠在較少的訓練樣本和較短的訓練時間內實現較好的圖像生成效果,可用於娛樂和學習。
🚀 快速開始
模型使用示例
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512")
gan.eval()
batch_size = 1
with torch.no_grad():
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
✨ 主要特性
- 訓練效率高:基於論文 Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis,該模型在單個 RTX - 2080 GPU 上僅需幾小時的訓練就能從初始狀態收斂。
- 小樣本表現好:即使訓練樣本少於 100 個,模型也能保持穩定的性能。
- 代碼可復現:使用從 lucidrains 倉庫 改編的腳本進行訓練,同時採用了官方倉庫的變換操作。
📦 安裝指南
文檔未提及具體安裝步驟,可參考 社區活動倉庫 進行安裝。
💻 使用示例
基礎用法
import torch
from huggan.pytorch.lightweight_gan.lightweight_gan import LightweightGAN
gan = LightweightGAN.from_pretrained("ceyda/butterfly_cropped_uniq1K_512")
gan.eval()
batch_size = 1
with torch.no_grad():
ims = gan.G(torch.randn(batch_size, gan.latent_dim)).clamp_(0., 1.)*255
ims = ims.permute(0,2,3,1).detach().cpu().numpy().astype(np.uint8)
📚 詳細文檔
模型描述
該模型基於 論文 Towards Faster and Stabilized GAN Training for High-fidelity Few-shot Image Synthesis,也被稱為 Light - GAN 模型。此模型使用 這裡 的腳本進行訓練,該腳本改編自 lucidrains 倉庫。與上述腳本不同的是,使用了官方倉庫的變換操作,因為訓練圖像已經經過裁剪和對齊。官方論文實現 倉庫。
transform_list = [
transforms.Resize((int(im_size),int(im_size))),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
]
預期用途與限制
- 預期用途:用於娛樂和學習。
- 限制和偏差:
- 訓練時,對數據集進行過濾,每個物種僅保留 1 只蝴蝶,否則模型生成的蝴蝶多樣性會降低。
- 數據集還使用 CLIP 分數對 ['pretty butterfly', 'one butterfly', 'butterfly with open wings', 'colorful butterfly'] 進行過濾,這種方法可能會導致某些蝴蝶因被判定為不符合標準而被排除在數據集之外,從而產生偏差。
訓練數據
使用了 1000 張圖像進行訓練。雖然可以增加圖像數量,但沒有時間手動整理數據集,同時也想驗證論文中提到的低數據訓練的可行性。更多細節可查看 數據卡片。
訓練過程
在 2 個 A4000 GPU 上訓練約 1 天,7 - 12 小時內即可看到較好的效果。重要參數:"--batch_size 64 --gradient_accumulate_every 4 --image_size 512 --mixed_precision fp16"。訓練日誌可查看 這裡。
評估結果
在 100 張圖像上計算了 FID 分數,不同檢查點的結果可查看 這裡,但由於 FID 分數的侷限性,其意義可能不大。
生成圖像
可以在 演示 中體驗模型生成的蝴蝶圖像。
BibTeX 引用和引用信息
該模型在 HugGAN 衝刺活動中開發。
模型訓練者:Ceyda Cinarel https://twitter.com/ceyda_cinarel
額外貢獻者:Jonathan Whitaker https://twitter.com/johnowhitaker
📄 許可證
本項目採用 MIT 許可證。
屬性 |
詳情 |
模型類型 |
蝴蝶生成對抗網絡(Butterfly GAN) |
訓練數據 |
1000 張蝴蝶圖像,來自 huggan/smithsonian_butterflies_subset 數據集 |