diff --git a/__assets__/docs/animatediff.md b/__assets__/docs/animatediff.md
index 63ff68ec..90e6937c 100644
--- a/__assets__/docs/animatediff.md
+++ b/__assets__/docs/animatediff.md
@@ -2,9 +2,11 @@
 ## Setups for Inference
 
 ### Prepare Environment
+**If you are using an Ascend NPU, see the instructions for [Ascend NPU support](animatediff_npu.md).**
 
 ***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!***
+
 ```
 git clone https://github.com/guoyww/AnimateDiff.git
 cd AnimateDiff
diff --git a/__assets__/docs/animatediff_npu.md b/__assets__/docs/animatediff_npu.md
new file mode 100644
index 00000000..821149e7
--- /dev/null
+++ b/__assets__/docs/animatediff_npu.md
@@ -0,0 +1,60 @@
+# Run AnimateDiff on Ascend NPU
+
+## Prepare Environment
+
+1. Clone this repository and install the packages:
+```shell
+git clone https://github.com/guoyww/AnimateDiff.git
+cd AnimateDiff
+
+conda env create -f environment.yaml
+conda activate animatediff
+```
+2. Install the Ascend Extension for PyTorch.
+
+Follow this [guide](https://www.hiascend.com/document/detail/en/ModelZoo/pytorchframework/ptes/ptes_00001.html) to download and install the Ascend NPU firmware, the Ascend NPU driver, and CANN. Afterwards, install the additional Python packages:
+```shell
+pip3 install torch==2.1.0+cpu --index-url https://download.pytorch.org/whl/cpu  # For x86
+pip3 install torch==2.1.0  # For AArch64
+pip3 install accelerate==0.28.0 diffusers==0.11.1 decorator==5.1.1 scipy==1.12.0 attrs==23.2.0 torchvision==0.16.0 transformers==4.25.1
+```
+After installing the Python packages above, follow this [README](https://github.com/Ascend/pytorch/blob/master/README.md) to set up the torch_npu environment. You can then use AnimateDiff on Ascend NPU.
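+
+To sanity-check the setup before moving on, you can run the short snippet below (a minimal sketch; it assumes `torch_npu` was installed as described above):
+```python
+import torch
+import torch_npu  # registers the "npu" device type with PyTorch
+
+print(torch.npu.is_available())   # True on a working setup
+print(torch.npu.device_count())   # number of visible NPUs
+```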
+
+## Prepare Checkpoints
+Follow this [README](animatediff.md) to prepare your checkpoints for inference, training, and finetuning.
+
+## Training/Finetuning AnimateDiff on Ascend NPU
+
+***Note: Ascend NPU does not support xformers acceleration, so the option `enable_xformers_memory_efficient_attention` in the YAML files under the `configs/training/v1/` directory needs to be set to `False`. torch_npu flash attention and other acceleration methods are integrated into the project to speed up training.***
+
+To train AnimateDiff on Ascend NPU, you only need to add a `source` command to your shell script, as shown below:
+```shell
+# First, add the Ascend environment variables to the current shell via the `source` command.
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/v1/image_finetune.yaml
+```
+
+## Inference with AnimateDiff on Ascend NPU
+
+To run inference on Ascend NPU, likewise just add the `source` command to your shell script, as shown below:
+```shell
+# First, add the Ascend environment variables to the current shell via the `source` command.
+source /usr/local/Ascend/ascend-toolkit/set_env.sh
+
+python -m scripts.animate --config configs/prompts/1-ToonYou.yaml
+python -m scripts.animate --config configs/prompts/2-Lyriel.yaml
+python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml
+python -m scripts.animate --config configs/prompts/4-MajicMix.yaml
+python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml
+python -m scripts.animate --config configs/prompts/6-Tusun.yaml
+python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml
+python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml
+```
\ No newline at end of file
diff --git a/animatediff/data/dataset.py b/animatediff/data/dataset.py
index 3f6ec102..fdb54c5a 100644
--- a/animatediff/data/dataset.py
+++ b/animatediff/data/dataset.py
@@ -31,7 +31,7 @@ def __init__(
         sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
         self.pixel_transforms = transforms.Compose([
             transforms.RandomHorizontalFlip(),
-            transforms.Resize(sample_size[0]),
+            transforms.Resize(sample_size[0], antialias=None),
             transforms.CenterCrop(sample_size),
             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
         ])
diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py
index ad23583c..b1e905ee 100644
--- a/animatediff/models/attention.py
+++ b/animatediff/models/attention.py
@@ -6,6 +6,15 @@
 import torch
 import torch.nn.functional as F
 from torch import nn
+from animatediff.utils.util import is_npu_available
+if is_npu_available():
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
+    from animatediff.models.attention_npu_monkey_patch import replace_with_torch_npu_flash_attention
+    from animatediff.models.attention_npu_monkey_patch import replace_with_torch_npu_geglu
+    replace_with_torch_npu_flash_attention()
+    replace_with_torch_npu_geglu()
+
 
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.modeling_utils import ModelMixin
diff --git a/animatediff/models/attention_npu_monkey_patch.py b/animatediff/models/attention_npu_monkey_patch.py
new file mode 100644
index 00000000..0f59e079
--- /dev/null
+++ b/animatediff/models/attention_npu_monkey_patch.py
@@ -0,0 +1,52 @@
+import math
+
+import torch
+import torch.nn.functional as F
+import torch_npu
+import diffusers
+
+
+def _attention(self, query, key, value, attention_mask=None):
+    if self.upcast_attention:
+        query = query.float()
+        key = key.float()
+
+    if query.dtype in (torch.float16, torch.bfloat16):
+        query = query.reshape(query.shape[0] // self.heads, self.heads, query.shape[1], query.shape[2])
+        key = key.reshape(key.shape[0] // self.heads, self.heads, key.shape[1], key.shape[2])
+        value = value.reshape(value.shape[0] // self.heads, self.heads, value.shape[1], value.shape[2])
+        hidden_states = torch_npu.npu_fusion_attention(
+            query, key, value, self.heads, input_layout="BNSD",
+            pse=None,
+            atten_mask=attention_mask,
+            scale=1.0 / math.sqrt(query.shape[-1]),
+            pre_tockens=65536,
+            next_tockens=65536,
+            keep_prob=1,
+            sync=False,
+            inner_precise=0,
+        )[0]
+
+        hidden_states = hidden_states.reshape(hidden_states.shape[0] * self.heads, hidden_states.shape[2],
+                                              hidden_states.shape[3])
+    else:
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+    # reshape hidden_states
+    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+    return hidden_states
+
+
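+# GEGLU forward that calls the fused torch_npu kernel (npu_geglu) instead of
+# diffusers' chunk + gelu implementation; installed by replace_with_torch_npu_geglu().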
+def geglu_forward(self, hidden_states):
+    hidden_states = self.proj(hidden_states)
+    return torch_npu.npu_geglu(hidden_states, dim=-1, approximate=1)[0]
+
+
+def replace_with_torch_npu_flash_attention():
+    diffusers.models.attention.CrossAttention._attention = _attention
+
+
+def replace_with_torch_npu_geglu():
+    diffusers.models.attention.GEGLU.forward = geglu_forward
diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py
index bcc1ddb8..4a2d2074 100644
--- a/animatediff/pipelines/pipeline_animation.py
+++ b/animatediff/pipelines/pipeline_animation.py
@@ -26,7 +26,7 @@
 from diffusers.utils import deprecate, logging, BaseOutput
 from einops import rearrange
 
-
+from animatediff.utils.util import is_npu_available
 from ..models.unet import UNet3DConditionModel
 from ..models.sparse_controlnet import SparseControlNetModel
 import pdb
@@ -129,8 +129,10 @@ def enable_sequential_cpu_offload(self, gpu_id=0):
             from accelerate import cpu_offload
         else:
             raise ImportError("Please install accelerate via `pip install accelerate`")
-
-        device = torch.device(f"cuda:{gpu_id}")
+        if is_npu_available():
+            device = torch.device(f"npu:{gpu_id}")
+        else:
+            device = torch.device(f"cuda:{gpu_id}")
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             if cpu_offloaded_model is not None:
diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py
index c0944837..8eae715a 100644
--- a/animatediff/utils/util.py
+++ b/animatediff/utils/util.py
@@ -1,5 +1,6 @@
 import os
 import imageio
+import importlib
 import numpy as np
 
 from typing import Union
@@ -170,3 +171,17 @@ def load_weights(
         animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
 
     return animation_pipeline
+
+def is_npu_available():
+    "Checks if `torch_npu` is installed and potentially if an NPU is in the environment"
+    if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None:
+        return False
+
+    import torch_npu
+
+    try:
+        # Will raise a RuntimeError if no NPU is found
+        _ = torch.npu.device_count()
+        return torch.npu.is_available()
+    except RuntimeError:
+        return False
\ No newline at end of file
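The entry-point scripts changed below (`app.py`, `scripts/animate.py`, `train.py`) all use this helper in the same way. A minimal usage sketch of that pattern (the tensor at the end is only illustrative):

```python
import torch

from animatediff.utils.util import is_npu_available

if is_npu_available():
    import torch_npu                               # Ascend backend for PyTorch
    from torch_npu.contrib import transfer_to_npu  # enables automatic CUDA-to-NPU call migration
    device = "npu"
else:
    device = "cuda"

# modules and tensors are then placed with .to(device) instead of .cuda()
x = torch.randn(2, 3).to(device)
```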
diff --git a/app.py b/app.py
index d4df5db9..aa53b607 100644
--- a/app.py
+++ b/app.py
@@ -17,10 +17,18 @@
 from animatediff.models.unet import UNet3DConditionModel
 from animatediff.pipelines.pipeline_animation import AnimationPipeline
-from animatediff.utils.util import save_videos_grid
+from animatediff.utils.util import save_videos_grid, is_npu_available
 from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
 from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+if is_npu_available():
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
+    device = "npu"
+else:
+    device = "cuda"
+
+
 sample_idx = 0
 
 scheduler_dict = {
@@ -81,9 +89,9 @@ def refresh_personalized_model(self):
 
     def update_stable_diffusion(self, stable_diffusion_dropdown):
         self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
-        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
-        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
-        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").to(device)
+        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").to(device)
+        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).to(device)
         return gr.Dropdown.update()
@@ -150,17 +158,17 @@ def animate(
         if base_model_dropdown == "":
             raise gr.Error(f"Please select a base DreamBooth model.")
 
-        if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
+        if is_xformers_available() and not is_npu_available(): self.unet.enable_xformers_memory_efficient_attention()
 
         pipeline = AnimationPipeline(
             vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
             scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
-        ).to("cuda")
+        ).to(device)
 
         if self.lora_model_state_dict != {}:
             pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
-            pipeline.to("cuda")
+            pipeline.to(device)
 
         if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
         else: torch.seed()
diff --git a/scripts/animate.py b/scripts/animate.py
index 22f4c2af..3f09751e 100644
--- a/scripts/animate.py
+++ b/scripts/animate.py
@@ -17,7 +17,7 @@
 from animatediff.models.sparse_controlnet import SparseControlNetModel
 from animatediff.pipelines.pipeline_animation import AnimationPipeline
 from animatediff.utils.util import save_videos_grid
-from animatediff.utils.util import load_weights
+from animatediff.utils.util import load_weights, is_npu_available
 from diffusers.utils.import_utils import is_xformers_available
 
 from einops import rearrange, repeat
@@ -27,6 +27,13 @@
 from PIL import Image
 import numpy as np
 
+if is_npu_available():
+    import torch_npu
+    from torch_npu.contrib import transfer_to_npu
+    device = "npu"
+else:
+    device = "cuda"
+
 
 @torch.no_grad()
 def main(args):
@@ -42,8 +49,8 @@
     # create validation pipeline
     tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").cuda()
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").cuda()
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder").to(device)
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae").to(device)
 
     sample_idx = 0
     for model_idx, model_config in enumerate(config):
@@ -52,7 +59,7 @@
         model_config.W = model_config.get("W", args.W)
         model_config.L = model_config.get("L", args.L)
 
         inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
-        unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).cuda()
+        unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs)).to(device)
 
         # load controlnet model
         controlnet = controlnet_images = None
@@ -71,7 +78,7 @@
controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict controlnet_state_dict.pop("animatediff_config", "") controlnet.load_state_dict(controlnet_state_dict) - controlnet.cuda() + controlnet.to(device) image_paths = model_config.controlnet_images if isinstance(image_paths, str): image_paths = [image_paths] @@ -102,7 +109,7 @@ def image_norm(image): for i, image in enumerate(controlnet_images): Image.fromarray((255. * (image.numpy().transpose(1,2,0))).astype(np.uint8)).save(f"{savedir}/control_images/{i}.png") - controlnet_images = torch.stack(controlnet_images).unsqueeze(0).cuda() + controlnet_images = torch.stack(controlnet_images).unsqueeze(0).to(device) controlnet_images = rearrange(controlnet_images, "b f c h w -> b c f h w") if controlnet.use_simplified_condition_embedding: @@ -113,6 +120,8 @@ def image_norm(image): # set xformers if is_xformers_available() and (not args.without_xformers): + if is_npu_available(): + raise ValueError("AscendNPU does not support xformers acceleration.") unet.enable_xformers_memory_efficient_attention() if controlnet is not None: controlnet.enable_xformers_memory_efficient_attention() @@ -120,7 +129,7 @@ def image_norm(image): vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, controlnet=controlnet, scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)), - ).to("cuda") + ).to(device) pipeline = load_weights( pipeline, @@ -134,7 +143,7 @@ def image_norm(image): dreambooth_model_path = model_config.get("dreambooth_path", ""), lora_model_path = model_config.get("lora_model_path", ""), lora_alpha = model_config.get("lora_alpha", 0.8), - ).to("cuda") + ).to(device) prompts = model_config.prompt n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt diff --git a/train.py b/train.py index 094e419a..f31a9968 100644 --- a/train.py +++ b/train.py @@ -37,9 +37,15 @@ from animatediff.data.dataset import WebVid10M from animatediff.models.unet import UNet3DConditionModel from animatediff.pipelines.pipeline_animation import AnimationPipeline -from animatediff.utils.util import save_videos_grid, zero_rank_print - +from animatediff.utils.util import save_videos_grid, zero_rank_print, is_npu_available +if is_npu_available(): + import torch_npu + from torch_npu.contrib import transfer_to_npu + torch.npu.config.allow_internal_format = False + device = "npu" +else: + device = "cuda" def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs): """Initializes distributed environment.""" @@ -212,7 +218,7 @@ def main( # Enable xformers if enable_xformers_memory_efficient_attention: - if is_xformers_available(): + if is_xformers_available() and not is_npu_available(): unet.enable_xformers_memory_efficient_attention() else: raise ValueError("xformers is not available. Make sure it is installed correctly") @@ -270,7 +276,7 @@ def main( if not image_finetune: validation_pipeline = AnimationPipeline( unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, - ).to("cuda") + ).to(device) else: validation_pipeline = StableDiffusionPipeline.from_pretrained( pretrained_model_path,