From 5e5ce13e2f89ac45a0066cb3f369462a3cf1d9ef Mon Sep 17 00:00:00 2001 From: Alex McKinney <44398246+vvvm23@users.noreply.github.com> Date: Fri, 3 Mar 2023 17:35:59 +0000 Subject: [PATCH] adds `xformers` support to `train_unconditional.py` (#2520) --- .../train_unconditional.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index b77be033a6..b46b4edb36 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -24,6 +24,7 @@ from diffusers.optimization import get_scheduler from diffusers.training_utils import EMAModel from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available # Will error if the minimal version of diffusers is not installed. Remove at your own risks. @@ -259,6 +260,9 @@ def parse_args(): ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' ), ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -410,6 +414,19 @@ def load_model_hook(models, input_dir): model_config=model.config, ) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + model.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + # Initialize the scheduler accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys()) if accepts_prediction_type: