Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AscendNPU support for Animatediff #335

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions __assets__/docs/animatediff.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
## Setups for Inference

### Prepare Environment
**If you are using Ascend NPU, see instructions for [AscendNPU support](animatediff_npu.md).**

***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!***


```
git clone https://github.com/guoyww/AnimateDiff.git
cd AnimateDiff
Expand Down
60 changes: 60 additions & 0 deletions __assets__/docs/animatediff_npu.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Run AnimateDiff on AscendNPU

## Prepare Environment


1. Clone this repository and Install package
```shell
git clone https://github.com/guoyww/AnimateDiff.git
cd AnimateDiff

conda env create -f environment.yaml
conda activate animatediff
```
2. Install Ascend Extension for PyTorch

You can follow this [guide](https://www.hiascend.com/document/detail/en/ModelZoo/pytorchframework/ptes/ptes_00001.html) to download and install the Ascend NPU Firmware, Ascend NPU Driver, and CANN. Afterwards, you need to install additional Python packages.
```shell
pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu # For X86
pip3 install torch==2.1.0 # For Aarch64
pip3 install accelerate==0.28.0 diffusers==0.11.1 decorator==5.1.1 scipy==1.12.0 attrs==23.2.0 torchvision==0.16.0 transformers==4.25.1
```
After installing the above Python packages,
You can follow this [README](https://github.com/Ascend/pytorch/blob/master/README.md) to install the torch_npu environment.
Then you can use AnimateDiff on Ascend NPU.

## Prepare Checkpoints
You can follow this [README](animatediff.md) to prepare your checkpoints for inference, training and finetune.

## Training/Finetune AnimateDiff on AscendNPU

***Note: AscendNPU does not support xformers acceleration, so the option 'enable_xformers_memory_efficient_attention' in the yaml file under 'training/v1/' directory needs to be changed to <font color=red>False</font>. I have integrated torch_npu flash attention and other acceleration methods into project that can speed up the training process.***

If you want to train animatediff on ascendnpu, you only add 'source' command on your shell scripts.

As shown below:
```shell
# Firstly, add environment variables to the system via the 'source' command.
source /usr/local/Ascend/ascend-toolkit/set_env.sh

torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/image_finetune.yaml
```

## Inference AnimateDiff on AscendNPU

If you want to inference animatediff on ascendnpu, you only add 'source' command on your shell scripts.

As shown below:
```shell
# Firstly, add environment variables to the system via the 'source' command.
source /usr/local/Ascend/ascend-toolkit/set_env.sh

python -m scripts.animate --config configs/prompts/1-ToonYou.yaml
python -m scripts.animate --config configs/prompts/2-Lyriel.yaml
python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml
python -m scripts.animate --config configs/prompts/4-MajicMix.yaml
python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml
python -m scripts.animate --config configs/prompts/6-Tusun.yaml
python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml
python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml
```
2 changes: 1 addition & 1 deletion animatediff/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(sample_size[0]),
transforms.Resize(sample_size[0], antialias=None),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
Expand Down
9 changes: 9 additions & 0 deletions animatediff/models/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
import torch
import torch.nn.functional as F
from torch import nn
from animatediff.utils.util import is_npu_available
if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu
from animatediff.models.attention_npu_monkey_patch import replace_with_torch_npu_flash_attention
from animatediff.models.attention_npu_monkey_patch import replace_with_torch_npu_geglu
replace_with_torch_npu_flash_attention()
replace_with_torch_npu_geglu()


from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
Expand Down
52 changes: 52 additions & 0 deletions animatediff/models/attention_npu_monkey_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import math

import torch
import torch.nn.functional as F
import torch_npu
import diffusers


def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()

if query.dtype in (torch.float16, torch.bfloat16):
query = query.reshape(query.shape[0] // self.heads, self.heads, query.shape[1], query.shape[2])
key = key.reshape(key.shape[0] // self.heads, self.heads, key.shape[1], key.shape[2])
value = value.reshape(value.shape[0] // self.heads, self.heads, value.shape[1], value.shape[2])
hidden_states = torch_npu.npu_fusion_attention(
query, key, value, self.heads, input_layout="BNSD",
pse=None,
atten_mask=attention_mask,
scale=1.0 / math.sqrt(query.shape[-1]),
pre_tockens=65536,
next_tockens=65536,
keep_prob=1,
sync=False,
inner_precise=0,
)[0]

hidden_states = hidden_states.reshape(hidden_states.shape[0] * self.heads, hidden_states.shape[2],
hidden_states.shape[3])
else:
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states


def geglu_forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]


def replace_with_torch_npu_flash_attention():
diffusers.models.attention.CrossAttention._attention = _attention


def replace_with_torch_npu_geglu():
diffusers.models.attention.GEGLU.forward = geglu_forward
8 changes: 5 additions & 3 deletions animatediff/pipelines/pipeline_animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from diffusers.utils import deprecate, logging, BaseOutput

from einops import rearrange

from animatediff.utils.util import is_npu_available
from ..models.unet import UNet3DConditionModel
from ..models.sparse_controlnet import SparseControlNetModel
import pdb
Expand Down Expand Up @@ -129,8 +129,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
from accelerate import cpu_offload
else:
raise ImportError("Please install accelerate via `pip install accelerate`")

device = torch.device(f"cuda:{gpu_id}")
if is_npu_available():
device = torch.device(f"npu:{gpu_id}")
else:
device = torch.device(f"cuda:{gpu_id}")

for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
if cpu_offloaded_model is not None:
Expand Down
15 changes: 15 additions & 0 deletions animatediff/utils/util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import imageio
import importlib
import numpy as np
from typing import Union

Expand Down Expand Up @@ -170,3 +171,17 @@ def load_weights(
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)

return animation_pipeline

def is_npu_available():
"Checks if `torch_npu` is installed and potentially if a NPU is in the environment"
if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
return False

import torch_npu

try:
# Will raise a RuntimeError if no NPU is found
_ = torch.npu.device_count()
return torch.npu.is_available()
except RuntimeError:
return False
22 changes: 15 additions & 7 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import save_videos_grid, is_npu_available
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu
device = "npu"
else:
device = "cuda"



sample_idx = 0
scheduler_dict = {
Expand Down Expand Up @@ -81,9 +89,9 @@ def refresh_personalized_model(self):

def update_stable_diffusion(self, stable_diffusion_dropdown):
self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").to(device)
self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").to(device)
self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
return gr.Dropdown.update()

def update_motion_module(self, motion_module_dropdown):
Expand Down Expand Up @@ -150,17 +158,17 @@ def animate(
if base_model_dropdown == "":
raise gr.Error(f"Please select a base DreamBooth model.")

if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
if is_xformers_available() and not is_npu_available(): self.unet.enable_xformers_memory_efficient_attention()

pipeline = AnimationPipeline(
vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
).to("cuda")
).to(device)

if self.lora_model_state_dict != {}:
pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)

pipeline.to("cuda")
pipeline.to(device)

if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
else: torch.seed()
Expand Down
25 changes: 17 additions & 8 deletions scripts/animate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from animatediff.models.sparse_controlnet import SparseControlNetModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from animatediff.utils.util import load_weights, is_npu_available
from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange, repeat
Expand All @@ -27,6 +27,13 @@
from PIL import Image
import numpy as np

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu
device = "npu"
else:
device = "cuda"


@torch.no_grad()
def main(args):
Expand All @@ -42,8 +49,8 @@ def main(args):

# create validation pipeline
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").cuda()
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").cuda()
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)

sample_idx = 0
for model_idx, model_config in enumerate(config):
Expand All @@ -52,7 +59,7 @@ def main(args):
model_config.L = model_config.get("L", args.L)

inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).to(device)

# load controlnet model
controlnet = controlnet_images = None
Expand All @@ -71,7 +78,7 @@ def main(args):
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
controlnet_state_dict.pop("animatediff_config", "")
controlnet.load_state_dict(controlnet_state_dict)
controlnet.cuda()
controlnet.to(device)

image_paths = model_config.controlnet_images
if isinstance(image_paths, str): image_paths = [image_paths]
Expand Down Expand Up @@ -102,7 +109,7 @@ def image_norm(image):
for i, image in enumerate(controlnet_images):
Image.fromarray((255. * (image.numpy().transpose(1,2,0))).astype(np.uint8)).save(f"{savedir}/control_images/{i}.png")

controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda()
controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device)
controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w")

if controlnet.use_simplified_condition_embedding:
Expand All @@ -113,14 +120,16 @@ def image_norm(image):

# set xformers
if is_xformers_available() and (not args.without_xformers):
if is_npu_available():
raise ValueError("AscendNPU does not support xformers acceleration.")
unet.enable_xformers_memory_efficient_attention()
if controlnet is not None: controlnet.enable_xformers_memory_efficient_attention()

pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
controlnet=controlnet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")
).to(device)

pipeline = load_weights(
pipeline,
Expand All @@ -134,7 +143,7 @@ def image_norm(image):
dreambooth_model_path = model_config.get("dreambooth_path", ""),
lora_model_path = model_config.get("lora_model_path", ""),
lora_alpha = model_config.get("lora_alpha", 0.8),
).to("cuda")
).to(device)

prompts = model_config.prompt
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
Expand Down
14 changes: 10 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,15 @@
from animatediff.data.dataset import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print

from animatediff.utils.util import save_videos_grid, zero_rank_print, is_npu_available

if is_npu_available():
import torch_npu
from torch_npu.contrib import transfer_to_npu
torch.npu.config.allow_internal_format = False
device = "npu"
else:
device = "cuda"

def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
"""Initializes distributed environment."""
Expand Down Expand Up @@ -212,7 +218,7 @@ def main(

# Enable xformers
if enable_xformers_memory_efficient_attention:
if is_xformers_available():
if is_xformers_available() and not is_npu_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
Expand Down Expand Up @@ -270,7 +276,7 @@ def main(
if not image_finetune:
validation_pipeline = AnimationPipeline(
unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
).to("cuda")
).to(device)
else:
validation_pipeline = StableDiffusionPipeline.from_pretrained(
pretrained_model_path,
Expand Down