🚀 潜在扩散模型 (LDM)
潜在扩散模型(LDM)通过将图像形成过程分解为去噪自编码器的顺序应用,在图像数据及其他领域实现了最先进的合成效果。该模型在有限计算资源下进行训练的同时,还能保留高质量和灵活性。
🚀 快速开始
论文信息
- 论文标题:High-Resolution Image Synthesis with Latent Diffusion Models
- 摘要:通过将图像形成过程分解为去噪自编码器的顺序应用,扩散模型(DMs)在图像数据及其他领域取得了最先进的合成效果。此外,其公式允许在不重新训练的情况下使用引导机制来控制图像生成过程。然而,由于这些模型通常直接在像素空间中运行,强大的扩散模型的优化通常需要消耗数百个GPU天,并且由于顺序评估,推理成本很高。为了在有限的计算资源上进行扩散模型训练,同时保留其质量和灵活性,我们将它们应用于强大的预训练自编码器的潜在空间中。与以往的工作相比,在这种表示上训练扩散模型首次使得在复杂度降低和细节保留之间达到接近最优的平衡点成为可能,从而大大提高了视觉保真度。通过在模型架构中引入交叉注意力层,我们将扩散模型转变为适用于一般条件输入(如文本或边界框)的强大而灵活的生成器,并且以卷积方式实现高分辨率合成成为可能。我们的潜在扩散模型(LDMs)在图像修复方面达到了新的技术水平,并且在各种任务(包括无条件图像生成、语义场景合成和超分辨率)上表现出极具竞争力的性能,同时与基于像素的扩散模型相比,显著降低了计算要求。
- 作者:Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer
💻 使用示例
基础用法
使用管道进行推理
!pip install diffusers
from diffusers import DiffusionPipeline
model_id = "CompVis/ldm-celebahq-256"
pipeline = DiffusionPipeline.from_pretrained(model_id)
image = pipeline(num_inference_steps=200)["sample"]
image[0].save("ldm_generated_image.png")
高级用法
使用展开循环进行推理
!pip install diffusers
from diffusers import UNet2DModel, DDIMScheduler, VQModel
import torch
import PIL.Image
import numpy as np
import tqdm
seed = 3
unet = UNet2DModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="unet")
vqvae = VQModel.from_pretrained("CompVis/ldm-celebahq-256", subfolder="vqvae")
scheduler = DDIMScheduler.from_config("CompVis/ldm-celebahq-256", subfolder="scheduler")
torch_device = "cuda" if torch.cuda.is_available() else "cpu"
unet.to(torch_device)
vqvae.to(torch_device)
generator = torch.manual_seed(seed)
noise = torch.randn(
(1, unet.in_channels, unet.sample_size, unet.sample_size),
generator=generator,
).to(torch_device)
scheduler.set_timesteps(num_inference_steps=200)
image = noise
for t in tqdm.tqdm(scheduler.timesteps):
with torch.no_grad():
residual = unet(image, t)["sample"]
prev_image = scheduler.step(residual, t, image, eta=0.0)["prev_sample"]
image = prev_image
with torch.no_grad():
image = vqvae.decode(image)
image_processed = image.cpu().permute(0, 2, 3, 1)
image_processed = (image_processed + 1.0) * 127.5
image_processed = image_processed.clamp(0, 255).numpy().astype(np.uint8)
image_pil = PIL.Image.fromarray(image_processed[0])
image_pil.save(f"generated_image_{seed}.png")
🔍 模型生成示例




📄 许可证
本项目采用Apache-2.0许可证。