CSGO: Content-Style Composition in Text-to-Image Generation
This repository contains the official PyTorch implementation of our paper on content-style composition in text-to-image generation. It provides multiple model weights and supports several kinds of stylized synthesis.
Quick Start
1. Clone the code and prepare the environment
git clone https://github.com/instantX-research/CSGO
cd CSGO
# create env using conda
conda create -n CSGO python=3.9
conda activate CSGO
# install dependencies with pip
# for Linux and Windows users
pip install -r requirements.txt
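Before moving on, a quick stdlib-only check can confirm that the environment is usable. This is a minimal sketch: the module list is an assumption taken from the imports in the inference code further down (torch, diffusers, PIL, numpy, and the repository's own ip_adapter package, which is only importable from the repository root). requirements.txt remains the authoritative dependency list.

```python
import importlib.util
import sys

# Assumed list: top-level modules imported by the inference example.
# requirements.txt is authoritative; this tuple only approximates it.
REQUIRED_MODULES = ("torch", "diffusers", "PIL", "numpy", "ip_adapter")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the names in `modules` that cannot be found on the import path."""
    # find_spec locates a top-level module without importing it.
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    if sys.version_info < (3, 9):
        print("warning: the conda environment above targets Python 3.9")
    missing = missing_modules()
    print("missing modules:", ", ".join(missing) if missing else "none")
```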
2. Download pretrained weights (coming soon)
The easiest way to download the pretrained weights is from HuggingFace:
# first, ensure git-lfs is installed, see: https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage
git lfs install
# clone and move the weights
git clone https://huggingface.co/InstantX/CSGO
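Our method works with SDXL, an SDXL VAE, ControlNet, and an image encoder. Download these models and place them in the ./base_models folder; the inference code below expects stable-diffusion-xl-base-1.0, sdxl-vae-fp16-fix, IP-Adapter/sdxl_models/image_encoder, and TTPLanet_SDXL_Controlnet_Tile_Realistic inside it.
Tip: CSGO loads the ControlNet through diffusers' StableDiffusionXLControlNetPipeline. To load the TTPlanet tile ControlNet directly, clone it into ./base_models and rename its weights file to the default name diffusers expects: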
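One way to fetch the base models from Python is huggingface_hub's snapshot_download. This is a sketch under stated assumptions: the Hub repo IDs below are our guess at where each folder used by the inference code comes from, and the mapping and the function name download_base_models are hypothetical. The ControlNet is left out because the tip that follows covers it. Downloading the full SDXL repository fetches several large checkpoint variants; narrowing it with allow_patterns saves disk space.

```python
from pathlib import Path

# Hypothetical mapping: local folder name -> (assumed Hub repo ID, allow_patterns).
# None means "download every file in the repo".
BASE_MODEL_SOURCES = {
    "stable-diffusion-xl-base-1.0": ("stabilityai/stable-diffusion-xl-base-1.0", None),
    "sdxl-vae-fp16-fix": ("madebyollin/sdxl-vae-fp16-fix", None),
    # Only the SDXL image encoder subfolder of IP-Adapter is used by CSGO's example.
    "IP-Adapter": ("h94/IP-Adapter", ["sdxl_models/image_encoder/*"]),
}


def download_base_models(root="./base_models"):
    """Download each assumed source repo into root/<folder>."""
    # Imported lazily so this file can be loaded without huggingface_hub installed.
    from huggingface_hub import snapshot_download

    root_path = Path(root)
    for folder, (repo_id, patterns) in BASE_MODEL_SOURCES.items():
        snapshot_download(
            repo_id=repo_id,
            local_dir=root_path / folder,
            allow_patterns=patterns,
        )
```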
git clone https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic
mv TTPLanet_SDXL_Controlnet_Tile_Realistic/TTPLANET_Controlnet_Tile_realistic_v2_fp16.safetensors TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors
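Because the inference code loads everything from fixed relative paths, it can help to check the layout before starting a long run. This is a minimal sketch: the folder paths come from the inference code, while the specific marker files (model_index.json, config.json) are assumptions based on the standard diffusers and transformers folder layouts. The function name missing_files is hypothetical.

```python
from pathlib import Path

# Folder paths mirror the inference example; the marker files inside them are
# assumptions based on standard diffusers/transformers layouts.
BASE_MODEL_FILES = (
    "base_models/stable-diffusion-xl-base-1.0/model_index.json",
    "base_models/sdxl-vae-fp16-fix/config.json",
    "base_models/IP-Adapter/sdxl_models/image_encoder/config.json",
    "base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic/diffusion_pytorch_model.safetensors",
)


def missing_files(root=".", checkpoint="CSGO/csgo_4_32.bin"):
    """Return expected files (relative to root) that are not present yet."""
    root_path = Path(root)
    expected = (*BASE_MODEL_FILES, checkpoint)
    return [rel for rel in expected if not (root_path / rel).is_file()]


if __name__ == "__main__":
    missing = missing_files()
    print("all weights in place" if not missing else "missing:\n" + "\n".join(missing))
```

Pass checkpoint="CSGO/csgo.bin" (or another file) to match whichever CSGO weights you load.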
3. Inference
import numpy as np
import torch
from PIL import Image
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionXLControlNetPipeline,
)

# ip_adapter is the package shipped inside the CSGO repository
from ip_adapter import CSGO
from ip_adapter.utils import BLOCKS, controlnet_BLOCKS, resize_content
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Local paths; see step 2 for obtaining these weights
base_model_path = "./base_models/stable-diffusion-xl-base-1.0"
image_encoder_path = "./base_models/IP-Adapter/sdxl_models/image_encoder"
# csgo_4_32.bin matches num_content_tokens=4 and num_style_tokens=32 below.
# csgo.bin uses 16 style tokens, so pass num_style_tokens=16 if you load it instead.
csgo_ckpt = "./CSGO/csgo_4_32.bin"
pretrained_vae_name_or_path = "./base_models/sdxl-vae-fp16-fix"
controlnet_path = "./base_models/TTPLanet_SDXL_Controlnet_Tile_Realistic"
weight_dtype = torch.float16

vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=weight_dtype)
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=weight_dtype, use_safetensors=True)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path,
    controlnet=controlnet,
    torch_dtype=weight_dtype,
    add_watermarker=False,
    vae=vae,
)
pipe.enable_vae_tiling()

# Blocks that receive content and style features, for the UNet and the ControlNet
target_content_blocks = BLOCKS["content"]
target_style_blocks = BLOCKS["style"]
controlnet_target_content_blocks = controlnet_BLOCKS["content"]
controlnet_target_style_blocks = controlnet_BLOCKS["style"]

csgo = CSGO(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_content_tokens=4,
    num_style_tokens=32,
    target_content_blocks=target_content_blocks,
    target_style_blocks=target_style_blocks,
    controlnet_adapter=True,
    controlnet_target_content_blocks=controlnet_target_content_blocks,
    controlnet_target_style_blocks=controlnet_target_style_blocks,
    content_model_resampler=True,
    style_model_resampler=True,
)
style_name = "img_1.png"
content_name = "img_0.png"
style_image = Image.open(f"../assets/{style_name}").convert("RGB")
content_image = Image.open(f"../assets/{content_name}").convert("RGB")
caption = "a small house with a sheep statue on top of it"
num_sample = 4

# Image-driven style transfer
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,  # ControlNet conditioning image (already RGB)
    controlnet_conditioning_scale=0.6,
)
# Text editing-driven stylized synthesis
caption = "a small house"
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,
    controlnet_conditioning_scale=0.4,
)
# Text-driven stylized synthesis
caption = "a cat"
# If the content image still interferes with the result, replace it with a blank
# (black) image of the same size. PIL's .size is (width, height), whereas NumPy
# arrays are shaped (height, width, channels):
# content_image = Image.fromarray(np.zeros((content_image.size[1], content_image.size[0], 3), dtype=np.uint8))
images = csgo.generate(
    pil_content_image=content_image,
    pil_style_image=style_image,
    prompt=caption,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    content_scale=1.0,
    style_scale=1.0,
    guidance_scale=10,
    num_images_per_prompt=num_sample,
    num_samples=1,
    num_inference_steps=50,
    seed=42,
    image=content_image,
    controlnet_conditioning_scale=0.01,
)
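The example keeps the results in memory. Assuming csgo.generate returns a list of PIL images (as the underlying diffusers pipeline's .images output does), a hypothetical helper like save_images below writes them to disk. It only relies on each item having a PIL-style save(path) method.

```python
from pathlib import Path


def save_images(images, out_dir="outputs", prefix="csgo"):
    """Save each image as <out_dir>/<prefix>_<index>.png and return the paths."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, image in enumerate(images):
        path = out_path / f"{prefix}_{i}.png"
        # PIL.Image.Image.save infers the file format from the .png suffix.
        image.save(path)
        paths.append(path)
    return paths
```

For example, save_images(images, prefix="text_driven") after the last call above.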
Features
- Multiple model weights: csgo.bin, csgo_4_32.bin, and csgo_4_32_v2.bin, which use different numbers of style tokens and different training setups.
- Stylized synthesis capabilities: CSGO supports image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis.
Installation
The installation process mainly includes cloning the code, preparing the environment, and downloading the pretrained weights (coming soon). Please refer to the "Quick Start" section for detailed steps.
Usage Examples
Basic Usage
The inference code in the "Quick Start" section demonstrates the basic usage of CSGO, including image-driven style transfer, text editing-driven stylized synthesis, and text-driven stylized synthesis.
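The three csgo.generate calls in Quick Start share every argument except the prompt and controlnet_conditioning_scale (0.6, 0.4, and 0.01). As a sketch, the shared arguments can live in one place; the mode labels and the helper name generate_kwargs are hypothetical and not part of the CSGO API, while every value is copied from the Quick Start example.

```python
# Hypothetical mode labels; the scale values are those used in Quick Start.
CONTROLNET_SCALE_BY_MODE = {
    "image_style_transfer": 0.6,
    "text_editing": 0.4,
    "text_driven": 0.01,
}

NEGATIVE_PROMPT = (
    "text, watermark, lowres, low quality, worst quality, deformed, glitch, "
    "low contrast, noisy, saturation, blurry"
)


def generate_kwargs(mode, prompt, num_images=4, seed=42):
    """Build the keyword arguments that Quick Start passes to csgo.generate()."""
    if mode not in CONTROLNET_SCALE_BY_MODE:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(CONTROLNET_SCALE_BY_MODE)}")
    return {
        "prompt": prompt,
        "negative_prompt": NEGATIVE_PROMPT,
        "content_scale": 1.0,
        "style_scale": 1.0,
        "guidance_scale": 10,
        "num_images_per_prompt": num_images,
        "num_samples": 1,
        "num_inference_steps": 50,
        "seed": seed,
        "controlnet_conditioning_scale": CONTROLNET_SCALE_BY_MODE[mode],
    }
```

For instance, csgo.generate(pil_content_image=content_image, pil_style_image=style_image, image=content_image, **generate_kwargs("text_editing", "a small house")) reproduces the second Quick Start call. Keeping the ControlNet scale as the only per-mode setting reflects the example's pattern: lower values presumably loosen how closely the output follows the content image's layout.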
Documentation
Details
We currently provide three model weights, which differ in the number of style tokens and in training setup:
| Checkpoint | Content tokens | Style tokens | Notes |
|---|---|---|---|
| csgo.bin | 4 | 16 | None |
| csgo_4_32.bin | 4 | 32 | Uses DeepSpeed ZeRO-2 |
| csgo_4_32_v2.bin | 4 | 32 | Uses DeepSpeed ZeRO-2 + more (coming soon) |
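The num_content_tokens and num_style_tokens passed to CSGO(...) must agree with the checkpoint being loaded. A minimal sketch that derives them from the table above; the helper name csgo_token_kwargs is hypothetical, and the token counts are transcribed from the table.

```python
from pathlib import Path

# Token counts transcribed from the table: file name -> (content, style).
CHECKPOINT_TOKENS = {
    "csgo.bin": (4, 16),
    "csgo_4_32.bin": (4, 32),
    "csgo_4_32_v2.bin": (4, 32),
}


def csgo_token_kwargs(ckpt_path):
    """Return the token-count keyword arguments matching a CSGO checkpoint file."""
    name = Path(ckpt_path).name
    try:
        content, style = CHECKPOINT_TOKENS[name]
    except KeyError:
        raise ValueError(f"unknown CSGO checkpoint: {name}") from None
    return {"num_content_tokens": content, "num_style_tokens": style}
```

Usage: CSGO(pipe, image_encoder_path, csgo_ckpt, device, **csgo_token_kwargs(csgo_ckpt), ...) with the remaining arguments as in Quick Start.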
Pipeline
Capabilities
CSGO supports image-driven style transfer, text-driven stylized synthesis, and text editing-driven stylized synthesis.
For more results, visit our homepage.
Demos
Content-Style Composition
Cycle Translation
Text-Driven Style Synthesis
Text Editing-Driven Style Synthesis
Technical Details
This project is the official PyTorch implementation of our paper CSGO: Content-Style Composition in Text-to-Image Generation. It builds on SDXL and works with a separate VAE, ControlNet, and an image encoder.
License
This project is under the Apache-2.0 license.
Acknowledgements
This project is developed by the InstantX Team. All rights reserved.
Citation
If you find CSGO useful for your research, please star this repo and cite our work using the following BibTeX:
@article{xing2024csgo,
title={CSGO: Content-Style Composition in Text-to-Image Generation},
author={Peng Xing and Haofan Wang and Yanpeng Sun and Qixun Wang and Xu Bai and Hao Ai and Renyuan Huang and Zechao Li},
year={2024},
journal = {arXiv 2408.16766},
}