🚀 潛在擴散模型 (LDM)
潛在擴散模型(LDM)通過將圖像形成過程分解為去噪自編碼器的順序應用,在圖像數據及其他領域實現了最先進的合成效果。該模型在有限計算資源下進行訓練的同時,還能保留高質量和靈活性。
🚀 快速開始
論文信息
- 論文標題:High-Resolution Image Synthesis with Latent Diffusion Models
- 摘要:通過將圖像形成過程分解為去噪自編碼器的順序應用,擴散模型(DMs)在圖像數據及其他領域取得了最先進的合成效果。此外,其公式允許在不重新訓練的情況下使用引導機制來控制圖像生成過程。然而,由於這些模型通常直接在像素空間中運行,強大的擴散模型的優化通常需要消耗數百個GPU天,並且由於順序評估,推理成本很高。為了在有限的計算資源上進行擴散模型訓練,同時保留其質量和靈活性,我們將它們應用於強大的預訓練自編碼器的潛在空間中。與以往的工作相比,在這種表示上訓練擴散模型首次使得在複雜度降低和細節保留之間達到接近最優的平衡點成為可能,從而大大提高了視覺保真度。通過在模型架構中引入交叉注意力層,我們將擴散模型轉變為適用於一般條件輸入(如文本或邊界框)的強大而靈活的生成器,並且以卷積方式實現高分辨率合成成為可能。我們的潛在擴散模型(LDMs)在圖像修復方面達到了新的技術水平,並且在各種任務(包括無條件圖像生成、語義場景合成和超分辨率)上表現出極具競爭力的性能,同時與基於像素的擴散模型相比,顯著降低了計算要求。
- 作者:Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer
💻 使用示例
基礎用法
使用管道進行推理
!pip install diffusers
from diffusers import DiffusionPipeline
model_id = "CompVis/ldm-celebahq-256"
pipeline = DiffusionPipeline.from_pretrained(model_id)
image = pipeline(num_inference_steps=200)["sample"]
image[0].save("ldm_generated_image.png")
高級用法
使用展開循環進行推理
!pip install diffusers
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm
seed = 3
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)
vqvae.to(torch_device)
generator = torch.manual_seed(seed)
noise = torch.randn(
(1, unet.in_channels, unet.sample_size, unet.sample_size),
generator=generator,
).to(torch_device)
scheduler.set_timesteps(num_inference_steps=200)
image = noise
for t in tqdm.tqdm(scheduler.timesteps):
with torch.no_grad():
residual = unet(image, t)["sample"]
prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
image = prev_image
with torch.no_grad():
image = vqvae.decode(image)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save(f"generated_image_{seed}.png")
🔍 模型生成示例




📄 許可證
本項目採用Apache-2.0許可證。