From 6575f4c184d58f2e313843c1fd0859fd093c3728 Mon Sep 17 00:00:00 2001
From: LeoXing1996
Date: Thu, 28 Dec 2023 17:38:17 +0800
Subject: [PATCH] support scaled_dot_product_attention in motion-module +
 update README

---
 README.md                           | 24 ++++++++++++++++++++++++
 animatediff/models/motion_module.py | 19 ++++++++++++++++++-
 app.py                              |  9 ++++++++-
 environment-pt2.yaml                | 22 ++++++++++++++++++++++
 4 files changed, 72 insertions(+), 2 deletions(-)
 create mode 100644 environment-pt2.yaml

diff --git a/README.md b/README.md
index 2d96412..77df14e 100644
--- a/README.md
+++ b/README.md
@@ -10,6 +10,14 @@
 [![arXiv](https://img.shields.io/badge/arXiv-2312.13964-b31b1b.svg)](https://arxiv.org/abs/2312.13964)
 [![Project Page](https://img.shields.io/badge/PIA-Website-green)](https://pi-animator.github.io)
 [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia)
+
+[Open in HuggingFace](https://huggingface.co/spaces/Leoxing/PIA)
+
+
+You may also want to try other projects from our team:
+
+[MMagic](https://github.com/open-mmlab/mmagic)
+
 
 PIA is a personalized image animation method which can generate videos with **high motion controllability** and **strong text and image alignment**.
 
@@ -17,6 +25,10 @@ PIA is a personalized image animation method which can generate videos with **hi
 ## What's New
 
+[2023/12/28] PIA can animate a 1024x1024 image with just 16 GB of GPU memory using `scaled_dot_product_attention`!
+
+[2023/12/25] HuggingFace demo is available now! [🤗 Hub](https://huggingface.co/spaces/Leoxing/PIA/)
+
 [2023/12/22] Release the model and demo of PIA. Try it to make your personalized movie!
 
 - Online Demo on [OpenXLab](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia)
 
@@ -24,11 +36,23 @@ PIA is a personalized image animation method which can generate videos with **hi
 ## Setup
 
 ### Prepare Environment
+
+Use the following commands to install PyTorch==2.0.0 and the other dependencies:
+
+```
+conda env create -f environment-pt2.yaml
+conda activate pia
+```
+
+If you want to use a lower version of PyTorch (e.g., 1.13.1), use these commands instead:
+
 ```
 conda env create -f environment.yaml
 conda activate pia
 ```
 
+We strongly recommend PyTorch==2.0.0, which supports `scaled_dot_product_attention` for memory-efficient image animation.
+
 ### Download checkpoints
 
 - Download the Stable Diffusion v1-5
diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py
index 67a7b79..8a41312 100644
--- a/animatediff/models/motion_module.py
+++ b/animatediff/models/motion_module.py
@@ -458,6 +458,10 @@ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_m
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
+    def set_use_memory_efficient_attention_xformers(self, *args, **kwargs):
+        print("Set xformers for MotionModule's Attention.")
+        self._use_memory_efficient_attention_xformers = True
+
     def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
         # TODO attention_mask
         query = query.contiguous()
@@ -467,6 +471,14 @@ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
+    def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
+        query = query.contiguous()
+        key = key.contiguous()
+        value = value.contiguous()
+        hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
 
 class VersatileAttention(CrossAttention):
     def __init__(
@@ -532,7 +544,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
             attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
 
         # attention, what we cannot get enough of
-        if self._use_memory_efficient_attention_xformers:
+        if hasattr(F, 'scaled_dot_product_attention'):
+            # NOTE: PyTorch 2.0's scaled_dot_product_attention appears more memory
+            # efficient than xformers' memory_efficient_attention, so prefer it
+            hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
+            hidden_states = hidden_states.to(query.dtype)
+        elif self._use_memory_efficient_attention_xformers:
             hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
diff --git a/app.py b/app.py
index cfd1dbe..122017b 100644
--- a/app.py
+++ b/app.py
@@ -181,6 +181,7 @@ def animate(
         cfg_scale_slider,
         seed_textbox,
         ip_adapter_scale,
+        max_size,
         progress=gr.Progress(),
     ):
         if not self.loaded:
@@ -191,7 +192,7 @@ def animate(
         else:
             torch.seed()
             seed = torch.initial_seed()
-        init_img, h, w = preprocess_img(init_img)
+        init_img, h, w = preprocess_img(init_img, max_size)
         sample = self.pipeline(
             image=init_img,
             prompt=prompt_textbox,
@@ -317,6 +318,10 @@ def update_personalized_model():
             sample_step_slider = gr.Slider(
                 label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
 
+            max_size_slider = gr.Slider(
+                label='Max size (the long edge of the input image is resized to this value; larger values mean slower inference)',
+                value=512, step=64, minimum=512, maximum=1024)
+
             length_slider = gr.Slider(
                 label="Animation length", value=16, minimum=8, maximum=24, step=1)
             cfg_scale_slider = gr.Slider(
@@ -379,6 +384,7 @@ def GenerationMode(motion_scale_silder, option):
                 cfg_scale_slider,
                 seed_textbox,
                 ip_adapter_scale,
+                max_size_slider,
             ],
             outputs=[result_video]
         )
@@ -388,5 +394,6 @@
 
 if __name__ == "__main__":
     demo = ui()
+    demo.queue(3)
     demo.launch(server_name=args.server_name, server_port=args.port,
                 share=args.share, allowed_paths=['pia.png'])
diff --git a/environment-pt2.yaml b/environment-pt2.yaml
new file mode 100644
index 0000000..d2106b6
--- /dev/null
+++ b/environment-pt2.yaml
@@ -0,0 +1,22 @@
+name: pia
+channels:
+  - pytorch
+  - nvidia
+dependencies:
+  - python=3.10
+  - pytorch=2.0.0
+  - torchvision=0.15.0
+  - pytorch-cuda=11.8
+  - pip
+  - pip:
+    - diffusers==0.24.0
+    - transformers==4.25.1
+    - xformers
+    - imageio==2.33.1
+    - decord==0.6.0
+    - gdown
+    - einops
+    - omegaconf
+    - safetensors
+    - gradio
+    - wandb
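
Note that the dispatch added to `VersatileAttention.forward` keys on `hasattr(F, 'scaled_dot_product_attention')` rather than parsing a version string, so any PyTorch build that exposes the kernel takes the fast path. Below is a minimal sketch for sanity-checking the new `environment-pt2.yaml` environment before running the demo; the tensor shapes are illustrative values chosen for this example, not taken from PIA:

```python
# Sanity check for the PyTorch 2.0 attention path this patch prefers.
# Shapes are made up for illustration: (batch * heads, seq_len, head_dim).
import torch
import torch.nn.functional as F

print(torch.__version__)  # expect 2.0.0 with environment-pt2.yaml

if hasattr(F, 'scaled_dot_product_attention'):
    q = torch.randn(2, 16, 64)
    k = torch.randn(2, 16, 64)
    v = torch.randn(2, 16, 64)
    # Same call pattern as _memory_efficient_attention_pt20: no mask, no dropout.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None,
                                         dropout_p=0, is_causal=False)
    print('scaled_dot_product_attention OK, output shape:', out.shape)
else:
    print('PyTorch < 2.0: the motion module falls back to xformers when '
          '_use_memory_efficient_attention_xformers is set')
```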