模型简介
模型特点
模型能力
使用案例
🚀 Simple Diffusion XS
小体积,高质量
Simple Diffusion XS 是一个由小团队打造的文本到图像生成模型,旨在以有限预算创建紧凑且快速的模型,可在消费级显卡上完成全量训练。
🚀 快速开始
训练状态
训练状态,已暂停:第 N 16 个 epoch
训练结果示例
💻 使用示例
基础用法
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os
def encode_prompt(prompt, negative_prompt, device, dtype):
if negative_prompt is None:
negative_prompt = ""
with torch.no_grad():
positive_inputs = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
positive_embeddings = text_model.encode_texts(
positive_inputs.input_ids, positive_inputs.attention_mask
)
if positive_embeddings.ndim == 2:
positive_embeddings = positive_embeddings.unsqueeze(1)
positive_embeddings = positive_embeddings.to(device, dtype=dtype)
negative_inputs = tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
if negative_embeddings.ndim == 2:
negative_embeddings = negative_embeddings.unsqueeze(1)
negative_embeddings = negative_embeddings.to(device, dtype=dtype)
return torch.cat([negative_embeddings, positive_embeddings], dim=0)
def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
with torch.no_grad():
device, dtype = embeddings.device, embeddings.dtype
half = embeddings.shape[0] // 2
latent_shape = (half, 16, height // 8, width // 8)
latents = torch.randn(latent_shape, device=device, dtype=dtype)
embeddings = embeddings.repeat_interleave(half, dim=0)
scheduler.set_timesteps(num_inference_steps)
for t in tqdm(scheduler.timesteps, desc="Генерация"):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(latents, vae, output_type="pil"):
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
images = vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
images = (images * 255).round().astype("uint8")
images = [Image.fromarray(image) for image in images]
return images
# Example usage:
if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
prompt = "кот"
negative_prompt = "bad quality"
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
text_model = AutoModel.from_pretrained(
"visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
).to(device, dtype=dtype).eval()
embeddings = encode_prompt(prompt, negative_prompt, device, dtype)
pipeid = "AiArtLab/sdxs"
variant = "fp16"
unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")
height, width = 576, 576
num_inference_steps = 40
output_folder, project_name = "samples", "sdxs"
latents = generate_latents(
embeddings=embeddings,
height=height,
width=width,
num_inference_steps = num_inference_steps
)
images = decode_latents(latents, vae)
os.makedirs(output_folder, exist_ok=True)
for idx, image in enumerate(images):
image.save(f"{output_folder}/{project_name}_{idx}.jpg")
print("Images generated and saved to:", output_folder)
✨ 主要特性
- 轻量快速:专为消费级显卡设计,可在 16GB GPU 上完成全量训练,训练速度快。
- 多语言支持:采用多语言编码器 Mexma - SigLIP,支持 80 种语言。
- 高质量生成:虽模型体积小,但生成质量有望与 SDXL 相媲美。
- 多模态潜力:可同步文本嵌入和图像,支持图像嵌入与文本描述混合查询。
📚 详细文档
简介
我们是 AiArtLab,一个预算有限的小团队。我们的目标是创建一个紧凑且快速的模型,能够在消费级显卡上进行全量训练(而非 LoRA)。我们选择 U - Net 是因为它能够有效处理小数据集,即使在 16GB GPU(如 RTX 4080)上也能快速训练。由于预算仅数千美元,远低于 SDXL 等竞争对手(数千万美元),所以我们决定创建一个小而高效的模型,类似于 2015 年的 SD1.5 版本。
编码器架构(文本和图像)
我们对各种编码器进行了实验,发现像 LLaMA 或 T5 XXL 这样的大型模型对于高质量生成并非必需。我们需要一个能够理解查询上下文的编码器,更注重“提示理解”而非“提示跟随”。因此,我们选择了多语言编码器 Mexma - SigLIP,它支持 80 种语言,并且处理句子而非单个标记。Mexma 最多接受 512 个标记,会生成一个大矩阵,从而减慢训练速度。所以,我们使用了一个池化层,将 512x1152 的矩阵简化为 1x1152 的向量。具体来说,我们将其通过一个线性模型/文本投影器,以实现与 SigLIP 嵌入的兼容性。这使我们能够同步文本嵌入和图像,有望实现统一的多模态模型。该功能允许在查询中混合图像嵌入和文本描述。此外,该模型可以仅使用图像进行训练,而无需文本描述。这将简化视频训练(视频标注具有挑战性),并通过输入带有衰减的前一帧嵌入来实现更一致、无缝的视频生成。未来,我们计划将模型扩展到 3D/视频生成。
U - Net 架构
我们选择了平滑的通道金字塔:[384, 576, 768, 960],每个块有两层,以及 [4, 6, 8, 10] 个变换器,每个变换器有 1152/48 = 24 个注意力头。这种架构在模型大小约为 20 亿参数的情况下提供了最高的训练速度(并且非常适合我的 RTX 4080)。我们相信,由于其更大的“深度”,尽管“尺寸”较小,但其质量将与 SDXL 相当。通过添加一个 1152 层,模型可以扩展到 40 亿参数,实现与嵌入大小的完美对称,我们认为这种对称很优雅,并且可能达到“Flux/MJ 级别”的质量。
VAE 架构
我们选择了非常规的 8x 16 通道 AuraDiffusion VAE,它能够保留细节、文本和人体结构,没有 SD3/Flux 那种“模糊”的问题。我们使用了带有 FFN 卷积的快速版本,观察到在精细图案上有轻微的纹理损坏,这可能会降低其在基准测试中的评分。像 ESRGAN 这样的超分辨率器可以解决这些伪影问题。总体而言,我们认为这个 VAE 被严重低估了。
训练过程
优化器
我们测试了几种优化器(AdamW、Laion、Optimi - AdamW、Adafactor 和 AdamW - 8bit),最终选择了 AdamW - 8bit。Optimi - AdamW 的梯度衰减曲线最平滑,尽管 AdamW - 8bit 的表现更不稳定。然而,它的体积更小,允许更大的批量大小,从而在低成本 GPU 上最大化训练速度(我们使用 4 个 A6000 和 5 个 L40 进行训练)。
学习率
我们发现调整衰减/预热曲线有一定效果,但并不显著。最佳学习率往往被高估。我们的实验表明,Adam 允许较宽的学习率范围。我们从 1e - 4 开始,在训练过程中逐渐降低到 1e - 6。换句话说,选择正确的模型架构远比调整超参数重要。
数据集
我们在大约 100 万张图像上训练了该模型:在 256 分辨率的 ImageNet 上训练了 60 个 epoch(由于低质量标注浪费了时间),在 CaptionEmporium/midjourney - niji - 1m - llavanext 上训练了 8 个 epoch,以及在 576 分辨率的真实照片和动漫/艺术图像上进行训练。我们使用人工提示、Caption Emporium 提供的提示、SmilingWolf 的 WD - Tagger 和 Moondream2 进行标注,通过改变提示长度和组成来确保模型理解不同的提示风格。数据集非常小,导致模型错过许多实体,并且在处理未见概念(如“自行车上的鹅”)时存在困难。数据集中还包含许多女仆风格的图像,因为我们更关注模型学习人体结构的能力,而不是绘制“骑马的宇航员”的技能。虽然大多数描述是英文的,但我们的测试表明该模型支持多语言。
局限性
- 概念覆盖有限:由于数据集极小,模型对概念的覆盖有限。
- 图像到图像功能待完善:图像到图像功能需要进一步训练(我们将 SigLIP 部分减少到 5%,以专注于文本到图像的训练)。
致谢
- Stan — 关键投资者。提供主要资金支持,感谢您在他人认为这是疯狂之举时对我们的信任。
- Captainsaturnus — 物质支持。
- Lovescape & Whargarbl — 精神支持。
- CaptionEmporium — 数据集提供方。
"我们相信未来在于高效、紧凑的模型。感谢您的捐赠,希望继续得到您的支持。"
训练预算
捐赠
如果您能提供 GPU 或资金用于训练,请与我们联系。
狗狗币(DOGE)地址:DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83 比特币(BTC)地址:3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
联系方式
📄 许可证
本项目采用 Apache - 2.0 许可证。









