forked from Tencent/MimicMotion
Isaac Gu committed Jul 1, 2024 · 1 parent ccb8486 · commit 1c91d48
Showing 33 changed files with 3,090 additions and 215 deletions.
README.md
@@ -1,2 +1,95 @@
# MimicMotion
A Pose-Guided Framework for Generating High-quality Videos of Arbitrary Length with Any Human Motion

<a href='http://tencent.github.io/MimicMotion'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2406.19680'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a>

MimicMotion: High-Quality Human Motion Video Generation with Confidence-aware Pose Guidance
<br/>
*Yuang Zhang<sup>1,2</sup>, Jiaxi Gu<sup>1</sup>, Li-Wen Wang<sup>1</sup>, Han Wang<sup>1,2</sup>, Junqi Cheng<sup>1</sup>, Yuefeng Zhu<sup>1</sup>, Fangyuan Zou<sup>1</sup>*
<br/>
[<sup>1</sup>Tencent; <sup>2</sup>Shanghai Jiao Tong University]

<p align="center">
  <img src="assets/figures/preview_1.gif" width="100" />
  <img src="assets/figures/preview_2.gif" width="100" />
  <img src="assets/figures/preview_3.gif" width="100" />
  <img src="assets/figures/preview_4.gif" width="100" />
  <img src="assets/figures/preview_5.gif" width="100" />
  <img src="assets/figures/preview_6.gif" width="100" />
  <br/>
  <span>Highlights: <b>rich details</b>, <b>good temporal smoothness</b>, and <b>long video length</b>.</span>
</p>

## Overview

<p align="center">
  <img src="assets/figures/model_structure.png" alt="model architecture" width="640"/>
  <br/>
  <i>An overview of the framework of MimicMotion.</i>
</p>

In recent years, generative artificial intelligence has achieved significant advances in image generation, spawning a variety of applications. However, video generation still faces considerable challenges in controllability, video length, and richness of detail, which hinder the application and popularization of this technology. In this work, we propose a controllable video generation framework, dubbed *MimicMotion*, which can generate high-quality videos of arbitrary length with any motion guidance. Compared with previous methods, our approach has several highlights. First, confidence-aware pose guidance improves temporal smoothness and lets the model be trained robustly on large-scale data. Second, regional loss amplification based on pose confidence significantly alleviates image distortion. Lastly, to generate long, smooth videos, we propose a progressive latent fusion strategy, by which videos of arbitrary length can be generated with acceptable resource consumption. Extensive experiments and user studies show that MimicMotion improves significantly over previous approaches in multiple aspects.

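To make the progressive latent fusion idea concrete, here is a minimal conceptual sketch (not the repository's implementation): long videos are assembled from overlapping temporal tiles whose shared frames are cross-faded with progressively increasing weights. The tile shapes, linear weighting, and function name below are illustrative assumptions.

```python
import torch

def fuse_overlapping_tiles(tiles, overlap):
    """Blend consecutive latent tiles that share `overlap` frames by
    linearly ramping the contribution of the newer tile across the overlap."""
    fused = tiles[0]
    for tile in tiles[1:]:
        w = torch.linspace(0, 1, overlap).view(-1, 1, 1, 1)  # cross-fade weights
        blended = (1 - w) * fused[-overlap:] + w * tile[:overlap]
        fused = torch.cat([fused[:-overlap], blended, tile[overlap:]], dim=0)
    return fused

# e.g. three 16-frame latent tiles with 6 overlapping frames -> 36 fused frames
tiles = [torch.randn(16, 4, 72, 128) for _ in range(3)]
video_latents = fuse_overlapping_tiles(tiles, overlap=6)  # shape (36, 4, 72, 128)
```

In the inference script further down, the tile size and overlap correspond to the `num_frames` and `frames_overlap` fields of each test case.
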
## Quickstart

### Environment setup

Python 3 with PyTorch 2.x is recommended; the setup has been validated on an NVIDIA V100 GPU. Run the commands below to install all Python dependencies:

```
conda env create -f environment.yaml
conda activate mimicmotion
```

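For setups without conda, a roughly equivalent pip install can be assembled from the pinned dependencies in `environment.yaml` (this is an assumption, not something the repository ships; CUDA-enabled torch wheels may require the index URL matching your driver):

```
pip install torch==2.0.1 torchvision==0.15.2 diffusers==0.27.0 transformers==4.32.1 decord==0.6.0 einops omegaconf
```
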
### Download weights
Please download the weights manually as follows:
```
cd MimicMotion/
mkdir models
```
1. Download the SVD model: [stabilityai/stable-video-diffusion-img2vid-xt-1-1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1)
```
git lfs install
git clone https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1
mkdir -p models/SVD
mv stable-video-diffusion-img2vid-xt-1-1 models/SVD/
```
2. Download the DWPose pretrained models: [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
```
git lfs install
git clone https://huggingface.co/yzd-v/DWPose
mv DWPose models/
```
3. Download the pre-trained checkpoint of MimicMotion from [Hugging Face](https://huggingface.co/ixaac/MimicMotion):
```
curl -o models/MimicMotion.pth https://huggingface.co/ixaac/MimicMotion/resolve/main/MimicMotion.pth
```
Finally, all the weights should be organized under `models/` as follows:
```
models/
├── DWPose
│   ├── dw-ll_ucoco_384.onnx
│   └── yolox_l.onnx
├── SVD
│   └── stable-video-diffusion-img2vid-xt-1-1
└── MimicMotion.pth
```
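As an optional sanity check (not part of the repository's instructions), you can list the expected files before running inference:
```
ls models/MimicMotion.pth models/DWPose/dw-ll_ucoco_384.onnx models/DWPose/yolox_l.onnx
ls models/SVD/stable-video-diffusion-img2vid-xt-1-1
```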
### Model inference
We provide an inference script. Run it with:
```
python inference.py --inference_config configs/test.yaml
```
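For reference, `configs/test.yaml` holds a list of `test_case` entries whose fields are read by `inference.py` (shown later in this commit). The entry below is illustrative only: the paths and values are placeholders, and the real config may contain additional settings such as model paths for the pipeline loader.

```yaml
test_case:
  - ref_video_path: path/to/pose_video.mp4       # video providing the motion
    ref_image_path: path/to/reference_image.jpg  # image providing the appearance
    resolution: 576
    sample_stride: 2
    num_frames: 16           # tile size for progressive latent fusion
    frames_overlap: 6        # overlapping frames between tiles
    num_inference_steps: 25
    noise_aug_strength: 0.0
    guidance_scale: 2.0
    seed: 42
    fps: 15
```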
## Citation
```bib
@article{mimicmotion2024,
  title={MimicMotion: High-Quality Human Motion Video Generation with Confidence-aware Pose Guidance},
  author={Yuang Zhang and Jiaxi Gu and Li-Wen Wang and Han Wang and Junqi Cheng and Yuefeng Zhu and Fangyuan Zou},
  journal={arXiv preprint arXiv:2406.19680},
  year={2024}
}
```
Binary files (assets) not shown.
constants.py
@@ -0,0 +1,2 @@
# w/h aspect ratio
ASPECT_RATIO = 9 / 16
environment.yaml
@@ -0,0 +1,16 @@
name: mimicmotion
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.11
  - pytorch=2.0.1
  - torchvision=0.15.2
  - pytorch-cuda=11.7
  - pip
  - pip:
    - diffusers==0.27.0
    - transformers==4.32.1
    - decord==0.6.0
    - einops
    - omegaconf
inference.py
@@ -0,0 +1,136 @@
import os
import argparse
import logging
import math
from omegaconf import OmegaConf
from datetime import datetime
from pathlib import Path

import numpy as np
import torch.jit
from torchvision.datasets.folder import pil_loader
from torchvision.transforms.functional import pil_to_tensor, resize, center_crop
from torchvision.transforms.functional import to_pil_image

from constants import ASPECT_RATIO

from mimicmotion.pipelines.pipeline_mimicmotion import MimicMotionPipeline
from mimicmotion.utils.loader import create_pipeline
from mimicmotion.utils.utils import save_to_mp4
from mimicmotion.dwpose.preprocess import get_video_pose, get_image_pose

logging.basicConfig(level=logging.INFO, format="%(asctime)s: [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def preprocess(video_path, image_path, resolution=576, sample_stride=2):
    """Preprocess the reference image pose and the video pose.

    Args:
        video_path (str): input pose video path
        image_path (str): reference image path
        resolution (int, optional): Defaults to 576.
        sample_stride (int, optional): Defaults to 2.
    """
    image_pixels = pil_loader(image_path)
    image_pixels = pil_to_tensor(image_pixels)  # (c, h, w)
    h, w = image_pixels.shape[-2:]
    ############################ compute target h/w according to original aspect ratio ###############################
    if h > w:
        w_target, h_target = resolution, int(resolution / ASPECT_RATIO // 64) * 64
    else:
        w_target, h_target = int(resolution / ASPECT_RATIO // 64) * 64, resolution
    h_w_ratio = float(h) / float(w)
    if h_w_ratio < h_target / w_target:
        h_resize, w_resize = h_target, math.ceil(h_target / h_w_ratio)
    else:
        h_resize, w_resize = math.ceil(w_target * h_w_ratio), w_target
    image_pixels = resize(image_pixels, [h_resize, w_resize], antialias=None)
    image_pixels = center_crop(image_pixels, [h_target, w_target])
    image_pixels = image_pixels.permute((1, 2, 0)).numpy()
    ##################################### get image & video pose values ##############################################
    image_pose = get_image_pose(image_pixels)
    video_pose = get_video_pose(video_path, image_pixels, sample_stride=sample_stride)
    pose_pixels = np.concatenate([np.expand_dims(image_pose, 0), video_pose])
    image_pixels = np.transpose(np.expand_dims(image_pixels, 0), (0, 3, 1, 2))
    # normalize pixel values from [0, 255] to [-1, 1]
    return torch.from_numpy(pose_pixels.copy()) / 127.5 - 1, torch.from_numpy(image_pixels) / 127.5 - 1
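
# Worked example of the sizing logic above (illustrative values, not from the repo):
# an 800x600 reference image (h=600, w=800) with resolution=576 is landscape, so
# w_target = int(576 / (9/16) // 64) * 64 = 1024 and h_target = 576; its h/w ratio
# 0.75 exceeds 576/1024 = 0.5625, so the image is resized to (768, 1024) and then
# center-cropped to (576, 1024).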


def run_pipeline(pipeline: MimicMotionPipeline, image_pixels, pose_pixels, device, task_config):
    image_pixels = [to_pil_image(img.to(torch.uint8)) for img in (image_pixels + 1.0) * 127.5]
    pose_pixels = pose_pixels.unsqueeze(0).to(device)
    generator = torch.Generator(device=device)
    generator.manual_seed(task_config.seed)
    frames = pipeline(
        image_pixels, image_pose=pose_pixels, num_frames=pose_pixels.size(1),
        tile_size=task_config.num_frames, tile_overlap=task_config.frames_overlap,
        height=pose_pixels.shape[-2], width=pose_pixels.shape[-1], fps=7,
        noise_aug_strength=task_config.noise_aug_strength, num_inference_steps=task_config.num_inference_steps,
        generator=generator, min_guidance_scale=task_config.guidance_scale,
        max_guidance_scale=task_config.guidance_scale, decode_chunk_size=8, output_type="pt", device=device
    ).frames.cpu()
    video_frames = (frames * 255.0).to(torch.uint8)

    for vid_idx in range(video_frames.shape[0]):
        # drop the first frame, which corresponds to the reference image
        _video_frames = video_frames[vid_idx, 1:]

    return _video_frames


@torch.no_grad()
def main(args):
    if not args.no_use_float16:
        torch.set_default_dtype(torch.float16)

    infer_config = OmegaConf.load(args.inference_config)
    pipeline = create_pipeline(infer_config, device)

    for task in infer_config.test_case:
        ############################################## Pre-process data ##############################################
        pose_pixels, image_pixels = preprocess(
            task.ref_video_path, task.ref_image_path,
            resolution=task.resolution, sample_stride=task.sample_stride
        )
        ########################################### Run MimicMotion pipeline #########################################
        _video_frames = run_pipeline(
            pipeline,
            image_pixels, pose_pixels,
            device, task
        )
        ################################### save results to output folder ############################################
        save_to_mp4(
            _video_frames,
            f"{args.output_dir}/{os.path.basename(task.ref_video_path).split('.')[0]}"
            f"_{datetime.now().strftime('%Y%m%d%H%M%S')}.mp4",
            fps=task.fps,
        )


def set_logger(log_file=None, log_level=logging.INFO):
    log_handler = logging.FileHandler(log_file, "w")
    log_handler.setFormatter(
        logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s")
    )
    log_handler.setLevel(log_level)
    logger.addHandler(log_handler)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_file", type=str, default=None)
    parser.add_argument("--inference_config", type=str, default="configs/test.yaml")  # TODO
    parser.add_argument("--output_dir", type=str, default="outputs/", help="path to output")
    parser.add_argument("--no_use_float16",
                        action="store_true",
                        help="Disable float16 inference (float16 is used by default to speed up inference)",
                        )
    args = parser.parse_args()

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    set_logger(args.log_file
               if args.log_file is not None else f"{args.output_dir}/{datetime.now().strftime('%Y%m%d%H%M%S')}.log")
    main(args)
    logger.info("--- Finished ---")
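
Based on the argument parser above, an example invocation that writes to a custom output directory, keeps full float32 precision, and logs to a file (the paths are placeholders) would be:
```
python inference.py --inference_config configs/test.yaml --output_dir outputs/ --no_use_float16 --log_file outputs/run.log
```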
.gitignore
@@ -0,0 +1 @@
*.pyc
@@ -0,0 +1,64 @@
import os

import numpy as np
import torch

from .wholebody import Wholebody

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class DWposeDetector:
    """
    A pose detection method for image-like data.

    Parameters:
        model_det: (str) serialized ONNX format model path,
            such as https://huggingface.co/yzd-v/DWPose/blob/main/yolox_l.onnx
        model_pose: (str) serialized ONNX format model path,
            such as https://huggingface.co/yzd-v/DWPose/blob/main/dw-ll_ucoco_384.onnx
        device: (str) 'cpu' or 'cuda:{device_id}'
    """
    def __init__(self, model_det, model_pose, device='cpu'):
        self.pose_estimation = Wholebody(model_det=model_det, model_pose=model_pose, device=device)

    def __call__(self, oriImg):
        oriImg = oriImg.copy()
        H, W, C = oriImg.shape
        with torch.no_grad():
            candidate, score = self.pose_estimation(oriImg)
            nums, _, locs = candidate.shape
            # normalize keypoint coordinates to [0, 1]
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)
            # keypoint layout (per the slices below): 0-17 body, 18-23 feet,
            # 24-91 face, 92-112 left hand, 113- right hand
            body = candidate[:, :18].copy()
            body = body.reshape(nums * 18, locs)
            subset = score[:, :18].copy()
            for i in range(len(subset)):
                for j in range(len(subset[i])):
                    if subset[i][j] > 0.3:
                        subset[i][j] = int(18 * i + j)
                    else:
                        subset[i][j] = -1

            # un_visible = subset < 0.3
            # candidate[un_visible] = -1

            # foot = candidate[:, 18:24]

            faces = candidate[:, 24:92]

            hands = candidate[:, 92:113]
            hands = np.vstack([hands, candidate[:, 113:]])

            faces_score = score[:, 24:92]
            hands_score = np.vstack([score[:, 92:113], score[:, 113:]])

            bodies = dict(candidate=body, subset=subset, score=score[:, :18])
            pose = dict(bodies=bodies, hands=hands, hands_score=hands_score, faces=faces, faces_score=faces_score)

            return pose


dwpose_detector = DWposeDetector(
    model_det="models/DWPose/yolox_l.onnx",
    model_pose="models/DWPose/dw-ll_ucoco_384.onnx",
    device=device)
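
A minimal usage sketch for the detector defined above (the import path and image path are assumptions; the DWPose ONNX weights must already be in `models/DWPose/` as described in the README):

```python
import numpy as np
from PIL import Image

# assumed import path for the module shown above
from mimicmotion.dwpose import dwpose_detector

# load a reference image as an (H, W, 3) uint8 RGB array
img = np.array(Image.open("path/to/reference_image.jpg").convert("RGB"))

pose = dwpose_detector(img)
print(pose["bodies"]["candidate"].shape)      # normalized body keypoints
print(pose["hands"].shape, pose["faces"].shape)
```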