[Feature] Support scaled_dot_product_attention in motion-module #26

Merged · 1 commit · Dec 28, 2023
24 changes: 24 additions & 0 deletions README.md
@@ -10,25 +10,49 @@
[![arXiv](https://img.shields.io/badge/arXiv-2312.13964-b31b1b.svg)](https://arxiv.org/abs/2312.13964)
[![Project Page](https://img.shields.io/badge/PIA-Website-green)](https://pi-animator.github.io)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia)
<a target="_blank" href="https://huggingface.co/spaces/Leoxing/PIA">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg" alt="Open in HugginFace"/>
</a>

You may also want to try other projects from our team:
<a target="_blank" href="https://github.com/open-mmlab/mmagic">
<img src="https://github.com/open-mmlab/mmagic/assets/28132635/15aab910-f5c4-4b76-af9d-fe8eead1d930" height=20 alt="MMagic"/>
</a>

PIA is a personalized image animation method that can generate videos with **high motion controllability** and **strong text and image alignment**.


<img src="__assets__/image_animation/teaser/teaser.gif">

## What's New
[2023/12/28] PIA can animate a 1024x1024 image with just 16GB of GPU memory using `scaled_dot_product_attention`!

[2023/12/25] HuggingFace demo is available now! [🤗 Hub](https://huggingface.co/spaces/Leoxing/PIA/)

[2023/12/22] Released the model and demo of PIA. Try it to make your personalized movie!

- Online Demo on [OpenXLab](https://openxlab.org.cn/apps/detail/zhangyiming/PiaPia)
- Checkpoint on [Google Drive](https://drive.google.com/file/d/1RL3Fp0Q6pMD8PbGPULYUnvjqyRQXGHwN/view?usp=drive_link) or [![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models/detail/zhangyiming/PIA)

## Setup
### Prepare Environment

Use the following commands to install PyTorch 2.0.0 and the other dependencies:

```
conda env create -f environment-pt2.yaml
conda activate pia
```

If you want to use a lower version of PyTorch (e.g., 1.13.1), use the following commands instead:

```
conda env create -f environment.yaml
conda activate pia
```

We strongly recommend PyTorch 2.0.0, which supports `scaled_dot_product_attention` for memory-efficient image animation.
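
If you are unsure which PyTorch you ended up with, a quick check from the activated environment (a minimal sketch; it only assumes `torch` is importable) confirms whether `scaled_dot_product_attention` is available:

```python
import torch
import torch.nn.functional as F

print(torch.__version__)
# True on PyTorch >= 2.0, where the fused attention kernel exists
print(hasattr(F, "scaled_dot_product_attention"))
```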

### Download checkpoints
<li>Download the Stable Diffusion v1-5</li>

19 changes: 18 additions & 1 deletion animatediff/models/motion_module.py
@@ -458,6 +458,10 @@ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def set_use_memory_efficient_attention_xformers(self, *args, **kwargs):
print('Set Xformers for MotionModule\'s Attention.')
self._use_memory_efficient_attention_xformers = True

def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
@@ -467,6 +471,14 @@ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states

def _memory_efficient_attention_pt20(self, query, key, value, attention_mask):
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0, is_causal=False)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


class VersatileAttention(CrossAttention):
def __init__(
@@ -532,7 +544,12 @@ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
if hasattr(F, 'scaled_dot_product_attention'):
# NOTE: PyTorch 2.0's scaled_dot_product_attention appears more memory-efficient
# than xformers' memory_efficient_attention, so prefer it when available
hidden_states = self._memory_efficient_attention_pt20(query, key, value, attention_mask)
hidden_states = hidden_states.to(query.dtype)
elif self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
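
Outside the diff context, the new dispatch logic reduces to the standalone sketch below. It is an illustration, not the module's exact code: `attention_dispatch` is a hypothetical name, the xformers fallback is elided, and the tensor shapes assume the (batch × heads, seq_len, head_dim) layout the motion module uses after splitting heads.

```python
import torch
import torch.nn.functional as F

def attention_dispatch(query, key, value, attention_mask=None):
    # Prefer PyTorch 2.0's fused kernel when it exists; the real module
    # falls back to xformers' memory_efficient_attention otherwise.
    if hasattr(F, "scaled_dot_product_attention"):
        hidden_states = F.scaled_dot_product_attention(
            query.contiguous(), key.contiguous(), value.contiguous(),
            attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        return hidden_states.to(query.dtype)
    raise RuntimeError("this sketch requires PyTorch >= 2.0")

# (batch * heads, seq_len, head_dim)
q = torch.randn(16, 24, 40)
k = torch.randn(16, 24, 40)
v = torch.randn(16, 24, 40)
print(attention_dispatch(q, k, v).shape)  # torch.Size([16, 24, 40])
```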
9 changes: 8 additions & 1 deletion app.py
@@ -181,6 +181,7 @@ def animate(
cfg_scale_slider,
seed_textbox,
ip_adapter_scale,
max_size,
progress=gr.Progress(),
):
if not self.loaded:
@@ -191,7 +192,7 @@
else:
torch.seed()
seed = torch.initial_seed()
init_img, h, w = preprocess_img(init_img)
init_img, h, w = preprocess_img(init_img, max_size)
sample = self.pipeline(
image=init_img,
prompt=prompt_textbox,
@@ -317,6 +318,10 @@ def update_personalized_model():
sample_step_slider = gr.Slider(
label="Sampling steps", value=25, minimum=10, maximum=100, step=1)

max_size_slider = gr.Slider(
label='Max size (the long edge of the input image will be resized to this value; larger values mean slower inference)',
value=512, step=64, minimum=512, maximum=1024)

length_slider = gr.Slider(
label="Animation length", value=16, minimum=8, maximum=24, step=1)
cfg_scale_slider = gr.Slider(
@@ -379,6 +384,7 @@ def GenerationMode(motion_scale_silder, option):
cfg_scale_slider,
seed_textbox,
ip_adapter_scale,
max_size_slider
],
outputs=[result_video]
)
@@ -388,5 +394,6 @@ def GenerationMode(motion_scale_silder, option):

if __name__ == "__main__":
demo = ui()
demo.queue(3)
demo.launch(server_name=args.server_name,
server_port=args.port, share=args.share, allowed_paths=['pia.png'])
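
The new `max_size` value flows from the slider into `preprocess_img`, which now caps the long edge of the input image. That helper's body is not shown in this diff; the sketch below is one plausible implementation (`preprocess_img_sketch` is a hypothetical name, and the multiple-of-8 rounding is an assumption based on the usual VAE latent-grid constraint, not confirmed by this PR).

```python
from PIL import Image

def preprocess_img_sketch(img: Image.Image, max_size: int = 512):
    # Scale so the long edge equals max_size, preserving aspect ratio,
    # then round both sides down to multiples of 8 (assumed latent-grid
    # constraint; the repo's actual rounding rule may differ).
    w, h = img.size
    scale = max_size / max(w, h)
    new_w = max(8, int(w * scale) // 8 * 8)
    new_h = max(8, int(h * scale) // 8 * 8)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    return img, new_h, new_w  # matches `init_img, h, w` in app.py
```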
22 changes: 22 additions & 0 deletions environment-pt2.yaml
@@ -0,0 +1,22 @@
name: pia
channels:
  - pytorch
  - nvidia
dependencies:
  - python=3.10
  - pytorch=2.0.0
  - torchvision=0.15.0
  - pytorch-cuda=11.8
  - pip
  - pip:
      - diffusers==0.24.0
      - transformers==4.25.1
      - xformers
      - imageio==2.33.1
      - decord==0.6.0
      - gdown
      - einops
      - omegaconf
      - safetensors
      - gradio
      - wandb