Skip to content

Commit

Permalink
adds xformers support to train_unconditional.py (open-mmlab#2520)
Browse files Browse the repository at this point in the history
  • Loading branch information
vvvm23 authored Mar 3, 2023
1 parent 7f0f7e1 commit 5e5ce13
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, is_accelerate_version, is_tensorboard_available, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
Expand Down Expand Up @@ -259,6 +260,9 @@ def parse_args():
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
Expand Down Expand Up @@ -410,6 +414,19 @@ def load_model_hook(models, input_dir):
model_config=model.config,
)

if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers

xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
model.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

# Initialize the scheduler
accepts_prediction_type = "prediction_type" in set(inspect.signature(DDPMScheduler.__init__).parameters.keys())
if accepts_prediction_type:
Expand Down

0 comments on commit 5e5ce13

Please sign in to comment.