🚀 Control LoRA for Image Editing with THUDM/CogView4-6B
This project provides a Control LoRA for making minor edits to images with the THUDM/CogView4-6B model, offering a practical way to apply targeted style and setting changes.
✨ Features
- Base Model: Uses the THUDM/CogView4-6B model as the foundation.
- Dataset: Trained on the sayakpaul/OmniEdit-mini dataset.
- Library: Built with the diffusers library.
- Multiple Editing Examples: Demonstrates editing scenarios such as converting to an impasto painting style, a spring setting, or a stormy space setting (see the multi-prompt sketch after the usage example).
📦 Installation
No installation steps are given in the original document. The usage example below imports torch, diffusers, and finetrainers, so those packages must be available in your environment.
💻 Usage Examples
Basic Usage
```python
import torch
from diffusers import CogView4Pipeline
from diffusers.utils import load_image
from finetrainers.models.utils import _expand_linear_with_zeroed_weights
from finetrainers.patches import load_lora_weights
from finetrainers.patches.dependencies.diffusers.control import control_channel_concat

dtype = torch.bfloat16
device = torch.device("cuda")
generator = torch.Generator().manual_seed(0)

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=dtype)
in_channels = pipe.transformer.config.in_channels
patch_channels = pipe.transformer.patch_embed.proj.in_features

# Widen the patch embedding so it accepts the control latents concatenated
# along the channel dimension; the new weights are zero-initialized so the
# base model's behavior is unchanged before the LoRA is applied.
pipe.transformer.patch_embed.proj = _expand_linear_with_zeroed_weights(
    pipe.transformer.patch_embed.proj, new_in_features=2 * patch_channels
)

load_lora_weights(pipe, "finetrainers/CogView4-6B-Edit-LoRA-v0", "cogview4-lora")
pipe.set_adapters("cogview4-lora", 0.9)
pipe.to(device)

prompt = "Make the image look like it's from an ancient Egyptian mural."
control_image = load_image("examples/training/control/cogview4/omni_edit/validation_dataset/0.png")
height, width = 1024, 1024

with torch.no_grad():
    latents = pipe.prepare_latents(1, in_channels, height, width, dtype, device, generator)

    # Encode the image to edit into VAE latent space and apply the VAE's
    # shift/scale normalization.
    control_image = pipe.image_processor.preprocess(control_image, height=height, width=width)
    control_image = control_image.to(device=device, dtype=dtype)
    control_latents = pipe.vae.encode(control_image).latent_dist.sample(generator=generator)
    control_latents = (control_latents - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

    # Concatenate the control latents to the denoising latents on every
    # transformer forward pass for the duration of this context.
    with control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
        image = pipe(prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]

image.save("output.png")
```
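As a hypothetical follow-up, the control latents prepared above can be reused across several prompts to reproduce the editing scenarios listed under Features. The prompts below are illustrative, not taken from the original document:

```python
# Reuse the prepared control latents for several edits (illustrative prompts).
edit_prompts = [
    "Change the painting style to impasto.",
    "Change the setting to spring.",
    "Change the setting to a stormy space.",
]
with torch.no_grad(), control_channel_concat(pipe.transformer, ["hidden_states"], [control_latents], dims=[1]):
    for i, edit_prompt in enumerate(edit_prompts):
        image = pipe(edit_prompt, latents=latents, num_inference_steps=30, generator=generator).images[0]
        image.save(f"output_{i}.png")
```

If the full pipeline does not fit in GPU memory, replacing pipe.to(device) with the standard diffusers helper pipe.enable_model_cpu_offload() is a common alternative.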
📚 Documentation
🔧 Technical Details
The original document provides no further implementation details, but the usage example shows the core mechanism: the transformer's patch-embedding projection is widened with zero-initialized weights to accept twice the input channels, the image to be edited is VAE-encoded (with the VAE's shift/scale normalization applied), and the resulting control latents are concatenated to the denoising latents along the channel dimension on every transformer forward pass while the LoRA adapter drives the edit.
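For intuition, here is a minimal sketch of those two mechanisms. Both expand_linear_zero_init and channel_concat are hypothetical illustrations under the assumptions above, not finetrainers' actual implementations:

```python
from contextlib import contextmanager

import torch
import torch.nn as nn


def expand_linear_zero_init(linear: nn.Linear, new_in_features: int) -> nn.Linear:
    # Widen a Linear layer's input dimension. The extra weight columns are
    # zero-initialized, so the expanded layer initially ignores the new
    # (control) channels and reproduces the original layer's output exactly.
    expanded = nn.Linear(
        new_in_features,
        linear.out_features,
        bias=linear.bias is not None,
        dtype=linear.weight.dtype,
        device=linear.weight.device,
    )
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : linear.in_features].copy_(linear.weight)
        if linear.bias is not None:
            expanded.bias.copy_(linear.bias)
    return expanded


@contextmanager
def channel_concat(module: nn.Module, control: torch.Tensor, dim: int = 1):
    # Temporarily patch module.forward so the control tensor is concatenated
    # to the first input along `dim` on every call, restoring the original
    # forward on exit.
    original_forward = module.forward

    def patched_forward(hidden_states, *args, **kwargs):
        hidden_states = torch.cat([hidden_states, control], dim=dim)
        return original_forward(hidden_states, *args, **kwargs)

    module.forward = patched_forward
    try:
        yield
    finally:
        module.forward = original_forward
```

Zero-initializing the new weight columns means the control branch contributes nothing before fine-tuning, so training can start from the unmodified base model's behavior.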
📄 License
No license information is provided in the original document.
⚠️ Important Note
This is an experimental checkpoint; its poor generalization is a known limitation.