模型概述
模型特點
模型能力
使用案例
🚀 Simple Diffusion XS
小體積,高質量
Simple Diffusion XS 是一個由小團隊打造的文本到圖像生成模型,旨在以有限預算創建緊湊且快速的模型,可在消費級顯卡上完成全量訓練。
🚀 快速開始
訓練狀態
訓練狀態,已暫停:第 N 16 個 epoch
訓練結果示例
💻 使用示例
基礎用法
import torch
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import AutoModel, AutoTokenizer
from PIL import Image
from tqdm.auto import tqdm
import os
def encode_prompt(prompt, negative_prompt, device, dtype):
if negative_prompt is None:
negative_prompt = ""
with torch.no_grad():
positive_inputs = tokenizer(
prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
positive_embeddings = text_model.encode_texts(
positive_inputs.input_ids, positive_inputs.attention_mask
)
if positive_embeddings.ndim == 2:
positive_embeddings = positive_embeddings.unsqueeze(1)
positive_embeddings = positive_embeddings.to(device, dtype=dtype)
negative_inputs = tokenizer(
negative_prompt,
return_tensors="pt",
padding="max_length",
max_length=512,
truncation=True,
).to(device)
negative_embeddings = text_model.encode_texts(negative_inputs.input_ids, negative_inputs.attention_mask)
if negative_embeddings.ndim == 2:
negative_embeddings = negative_embeddings.unsqueeze(1)
negative_embeddings = negative_embeddings.to(device, dtype=dtype)
return torch.cat([negative_embeddings, positive_embeddings], dim=0)
def generate_latents(embeddings, height=576, width=576, num_inference_steps=50, guidance_scale=5.5):
with torch.no_grad():
device, dtype = embeddings.device, embeddings.dtype
half = embeddings.shape[0] // 2
latent_shape = (half, 16, height // 8, width // 8)
latents = torch.randn(latent_shape, device=device, dtype=dtype)
embeddings = embeddings.repeat_interleave(half, dim=0)
scheduler.set_timesteps(num_inference_steps)
for t in tqdm(scheduler.timesteps, desc="Генерация"):
latent_model_input = torch.cat([latents] * 2)
latent_model_input = scheduler.scale_model_input(latent_model_input, t)
noise_pred = unet(latent_model_input, t, embeddings).sample
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = scheduler.step(noise_pred, t, latents).prev_sample
return latents
def decode_latents(latents, vae, output_type="pil"):
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
with torch.no_grad():
images = vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
images = (images * 255).round().astype("uint8")
images = [Image.fromarray(image) for image in images]
return images
# Example usage:
if __name__ == "__main__":
device = "cuda"
dtype = torch.float16
prompt = "кот"
negative_prompt = "bad quality"
tokenizer = AutoTokenizer.from_pretrained("visheratin/mexma-siglip")
text_model = AutoModel.from_pretrained(
"visheratin/mexma-siglip", torch_dtype=dtype, trust_remote_code=True
).to(device, dtype=dtype).eval()
embeddings = encode_prompt(prompt, negative_prompt, device, dtype)
pipeid = "AiArtLab/sdxs"
variant = "fp16"
unet = UNet2DConditionModel.from_pretrained(pipeid, subfolder="unet", variant=variant).to(device, dtype=dtype).eval()
vae = AutoencoderKL.from_pretrained(pipeid, subfolder="vae", variant=variant).to(device, dtype=dtype).eval()
scheduler = DDPMScheduler.from_pretrained(pipeid, subfolder="scheduler")
height, width = 576, 576
num_inference_steps = 40
output_folder, project_name = "samples", "sdxs"
latents = generate_latents(
embeddings=embeddings,
height=height,
width=width,
num_inference_steps = num_inference_steps
)
images = decode_latents(latents, vae)
os.makedirs(output_folder, exist_ok=True)
for idx, image in enumerate(images):
image.save(f"{output_folder}/{project_name}_{idx}.jpg")
print("Images generated and saved to:", output_folder)
✨ 主要特性
- 輕量快速:專為消費級顯卡設計,可在 16GB GPU 上完成全量訓練,訓練速度快。
- 多語言支持:採用多語言編碼器 Mexma - SigLIP,支持 80 種語言。
- 高質量生成:雖模型體積小,但生成質量有望與 SDXL 相媲美。
- 多模態潛力:可同步文本嵌入和圖像,支持圖像嵌入與文本描述混合查詢。
📚 詳細文檔
簡介
我們是 AiArtLab,一個預算有限的小團隊。我們的目標是創建一個緊湊且快速的模型,能夠在消費級顯卡上進行全量訓練(而非 LoRA)。我們選擇 U - Net 是因為它能夠有效處理小數據集,即使在 16GB GPU(如 RTX 4080)上也能快速訓練。由於預算僅數千美元,遠低於 SDXL 等競爭對手(數千萬美元),所以我們決定創建一個小而高效的模型,類似於 2015 年的 SD1.5 版本。
編碼器架構(文本和圖像)
我們對各種編碼器進行了實驗,發現像 LLaMA 或 T5 XXL 這樣的大型模型對於高質量生成並非必需。我們需要一個能夠理解查詢上下文的編碼器,更注重“提示理解”而非“提示跟隨”。因此,我們選擇了多語言編碼器 Mexma - SigLIP,它支持 80 種語言,並且處理句子而非單個標記。Mexma 最多接受 512 個標記,會生成一個大矩陣,從而減慢訓練速度。所以,我們使用了一個池化層,將 512x1152 的矩陣簡化為 1x1152 的向量。具體來說,我們將其通過一個線性模型/文本投影器,以實現與 SigLIP 嵌入的兼容性。這使我們能夠同步文本嵌入和圖像,有望實現統一的多模態模型。該功能允許在查詢中混合圖像嵌入和文本描述。此外,該模型可以僅使用圖像進行訓練,而無需文本描述。這將簡化視頻訓練(視頻標註具有挑戰性),並通過輸入帶有衰減的前一幀嵌入來實現更一致、無縫的視頻生成。未來,我們計劃將模型擴展到 3D/視頻生成。
U - Net 架構
我們選擇了平滑的通道金字塔:[384, 576, 768, 960],每個塊有兩層,以及 [4, 6, 8, 10] 個變換器,每個變換器有 1152/48 = 24 個注意力頭。這種架構在模型大小約為 20 億參數的情況下提供了最高的訓練速度(並且非常適合我的 RTX 4080)。我們相信,由於其更大的“深度”,儘管“尺寸”較小,但其質量將與 SDXL 相當。通過添加一個 1152 層,模型可以擴展到 40 億參數,實現與嵌入大小的完美對稱,我們認為這種對稱很優雅,並且可能達到“Flux/MJ 級別”的質量。
VAE 架構
我們選擇了非常規的 8x 16 通道 AuraDiffusion VAE,它能夠保留細節、文本和人體結構,沒有 SD3/Flux 那種“模糊”的問題。我們使用了帶有 FFN 卷積的快速版本,觀察到在精細圖案上有輕微的紋理損壞,這可能會降低其在基準測試中的評分。像 ESRGAN 這樣的超分辨率器可以解決這些偽影問題。總體而言,我們認為這個 VAE 被嚴重低估了。
訓練過程
優化器
我們測試了幾種優化器(AdamW、Laion、Optimi - AdamW、Adafactor 和 AdamW - 8bit),最終選擇了 AdamW - 8bit。Optimi - AdamW 的梯度衰減曲線最平滑,儘管 AdamW - 8bit 的表現更不穩定。然而,它的體積更小,允許更大的批量大小,從而在低成本 GPU 上最大化訓練速度(我們使用 4 個 A6000 和 5 個 L40 進行訓練)。
學習率
我們發現調整衰減/預熱曲線有一定效果,但並不顯著。最佳學習率往往被高估。我們的實驗表明,Adam 允許較寬的學習率範圍。我們從 1e - 4 開始,在訓練過程中逐漸降低到 1e - 6。換句話說,選擇正確的模型架構遠比調整超參數重要。
數據集
我們在大約 100 萬張圖像上訓練了該模型:在 256 分辨率的 ImageNet 上訓練了 60 個 epoch(由於低質量標註浪費了時間),在 CaptionEmporium/midjourney - niji - 1m - llavanext 上訓練了 8 個 epoch,以及在 576 分辨率的真實照片和動漫/藝術圖像上進行訓練。我們使用人工提示、Caption Emporium 提供的提示、SmilingWolf 的 WD - Tagger 和 Moondream2 進行標註,通過改變提示長度和組成來確保模型理解不同的提示風格。數據集非常小,導致模型錯過許多實體,並且在處理未見概念(如“自行車上的鵝”)時存在困難。數據集中還包含許多女僕風格的圖像,因為我們更關注模型學習人體結構的能力,而不是繪製“騎馬的宇航員”的技能。雖然大多數描述是英文的,但我們的測試表明該模型支持多語言。
侷限性
- 概念覆蓋有限:由於數據集極小,模型對概念的覆蓋有限。
- 圖像到圖像功能待完善:圖像到圖像功能需要進一步訓練(我們將 SigLIP 部分減少到 5%,以專注於文本到圖像的訓練)。
致謝
- Stan — 關鍵投資者。提供主要資金支持,感謝您在他人認為這是瘋狂之舉時對我們的信任。
- Captainsaturnus — 物質支持。
- Lovescape & Whargarbl — 精神支持。
- CaptionEmporium — 數據集提供方。
"我們相信未來在於高效、緊湊的模型。感謝您的捐贈,希望繼續得到您的支持。"
訓練預算
捐贈
如果您能提供 GPU 或資金用於訓練,請與我們聯繫。
狗狗幣(DOGE)地址:DEw2DR8C7BnF8GgcrfTzUjSnGkuMeJhg83 比特幣(BTC)地址:3JHv9Hb8kEW8zMAccdgCdZGfrHeMhH1rpN
聯繫方式
📄 許可證
本項目採用 Apache - 2.0 許可證。









