🚀 🤖 多输入ResShift扩散视频帧插值(Multi‑Input ResShift Diffusion VFI)
多输入ResShift扩散视频帧插值模型主要用于视频帧插值任务,能够在动画、视频等场景中,根据已有帧生成中间帧,同时还支持不确定性估计,为视频处理提供了更丰富的功能和更准确的结果。
🚀 快速开始
环境搭建
首先,直接从GitHub下载源代码:
git clone https://github.com/VicFonch/Multi-Input-Resshift-Diffusion-VFI.git
创建一个conda环境并安装所有依赖项:
conda create -n multi-input-resshift python=3.12
conda activate multi-input-resshift
pip install -r requirements.txt
⚠️ 重要提示
请确保你的系统与 CUDA 12.4 兼容。如果不兼容,请根据你当前的CUDA版本安装 CuPy。
推理示例
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from torchvision.transforms import Compose, ToTensor, Resize, Normalize
from utils.utils import denorm
from model.hub import MultiInputResShiftHub
model = MultiInputResShiftHub.from_pretrained("vfontech/Multiple-Input-Resshift-VFI").cuda()
model.eval()
img0_path = r"_data\example_images\frame1.png"
img2_path = r"_data\example_images\frame3.png"
mean = std = [0.5]*3
transforms = Compose([
Resize((256, 448)),
ToTensor(),
Normalize(mean=mean, std=std),
])
img0 = transforms(Image.open(img0_path).convert("RGB")).unsqueeze(0).cuda()
img2 = transforms(Image.open(img2_path).convert("RGB")).unsqueeze(0).cuda()
tau = 0.5
img1 = model.reverse_process([img0, img2], tau)
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.imshow(denorm(img0, mean=mean, std=std).squeeze().permute(1, 2, 0).cpu().numpy())
plt.subplot(1, 3, 2)
plt.imshow(denorm(img1, mean=mean, std=std).squeeze().permute(1, 2, 0).cpu().numpy())
plt.subplot(1, 3, 3)
plt.imshow(denorm(img2, mean=mean, std=std).squeeze().permute(1, 2, 0).cpu().numpy())
plt.show()
📄 许可证
本项目采用MIT许可证。
属性 |
详情 |
模型类型 |
PyTorch模型 |
标签 |
pytorch_model_hub_mixin、animation、video-frame-interpolation、uncertainty-estimation |
任务类型 |
图像到图像 |