diff --git a/examples/stable-diffusion/text_to_image_generation.py b/examples/stable-diffusion/text_to_image_generation.py index bb4cdfb9ad..195e4ca7f3 100755 --- a/examples/stable-diffusion/text_to_image_generation.py +++ b/examples/stable-diffusion/text_to_image_generation.py @@ -220,6 +220,24 @@ def main(): default=0, help="Number of steps to capture for profiling.", ) + parser.add_argument( + "--unet_adapter_name_or_path", + default=None, + type=str, + help="Path to pre-trained model", + ) + parser.add_argument( + "--text_encoder_adapter_name_or_path", + default=None, + type=str, + help="Path to pre-trained model", + ) + parser.add_argument( + "--lora_id", + default=None, + type=str, + help="Path to lora id", + ) args = parser.parse_args() # Set image resolution @@ -311,6 +329,8 @@ def main(): controlnet=controlnet, **kwargs, ) + if args.lora_id: + pipeline.load_lora_weights(args.lora_id) # Set seed before running the model set_seed(args.seed) @@ -334,6 +354,8 @@ def main(): args.model_name_or_path, **kwargs, ) + if args.lora_id: + pipeline.load_lora_weights(args.lora_id) # Set seed before running the model set_seed(args.seed) @@ -358,8 +380,18 @@ def main(): args.model_name_or_path, **kwargs, ) - - # Set seed before running the model + if args.unet_adapter_name_or_path is not None: + from peft import PeftModel + + pipeline.unet = PeftModel.from_pretrained(pipeline.unet, args.unet_adapter_name_or_path) + pipeline.unet = pipeline.unet.merge_and_unload() + if args.text_encoder_adapter_name_or_path is not None: + from peft import PeftModel + + pipeline.text_encoder = PeftModel.from_pretrained( + pipeline.text_encoder, args.text_encoder_adapter_name_or_path + ) + pipeline.text_encoder = pipeline.text_encoder.merge_and_unload() set_seed(args.seed) outputs = pipeline( diff --git a/examples/stable-diffusion/training/README.md b/examples/stable-diffusion/training/README.md index a81e4b0ec1..38a52a62db 100644 --- a/examples/stable-diffusion/training/README.md +++ b/examples/stable-diffusion/training/README.md @@ -258,3 +258,172 @@ PT_HPU_MAX_COMPOUND_OP_SIZE=5 python train_text_to_image_sdxl.py \ > [!NOTE] > `--mediapipe` only works on Gaudi2. + + +## DreamBooth +DreamBooth is a method to personalize text-to-image models like Stable Diffusion given just a few (3~5) images of a subject. The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for Stable Diffusion. + +### Dog toy example + +Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example. + +Let's first download it locally: + +```python +from huggingface_hub import snapshot_download + +local_dir = "./dog" +snapshot_download( + "diffusers/dog-example", + local_dir=local_dir, repo_type="dataset", + ignore_patterns=".gitattributes", +) +``` + +### Full model finetune +And launch the multi-card training using: +```bash + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export INSTANCE_DIR="dog" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="out" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=$CLASS_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_class_images=200 \ + --gradient_accumulation_steps=1 \ + --learning_rate=5e-6 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --mixed_precision=bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference \ + --gaudi_config_name Habana/stable-diffusion \ + full + +``` +Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. + +### PEFT model finetune +We provide example for dreambooth to use lora/lokr/loha/oft to finetune unet or text encoder. + +**___Note: When using peft method we can use a much higher learning rate compared to vanilla dreambooth. Here we +use *1e-4* instead of the usual *5e-6*.___** + +Launch the multi-card training using: +```bash + +export MODEL_NAME="runwayml/stable-diffusion-v1-5" +export INSTANCE_DIR="dog" +export CLASS_DIR="path-to-class-images" +export OUTPUT_DIR="out" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --output_dir=$OUTPUT_DIR \ + --class_data_dir=$CLASS_DIR \ + --with_prior_preservation --prior_loss_weight=1.0 \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" \ + --resolution=512 \ + --train_batch_size=1 \ + --num_class_images=200 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=800 \ + --mixed_precision=bf16 \ + --use_hpu_graphs_for_training \ + --use_hpu_graphs_for_inference \ + --gaudi_config_name Habana/stable-diffusion \ + lora --unet_r 8 --unet_alpha 8 + +``` +Similar command could be applied to loha, lokr, oft. +You could check each adapter specific args by "--help", like you could use following command to check oft specific args. + +```bash +python3 train_dreambooth.py oft --help + +``` + +**___Note: oft could not work with hpu graphs mode. since "torch.inverse" need to fallback to cpu. +there's error like "cpu fallback is not supported during hpu graph capturing"___** + + +You could use text_to_image_generation.py to generate picture using the peft adapter like + +```bash +python ../text_to_image_generation.py \ + --model_name_or_path runwayml/stable-diffusion-v1-5 \ + --prompts "a sks dog" \ + --num_images_per_prompt 5 \ + --batch_size 1 \ + --image_save_dir /tmp/stable_diffusion_images \ + --use_habana \ + --use_hpu_graphs \ + --unet_adapter_name_or_path out/unet \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` + +### DreamBooth training example for Stable Diffusion XL +You could use the dog images as example as well. +You can launch training using: +```bash +export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0" +export INSTANCE_DIR="dog" +export OUTPUT_DIR="lora-trained-xl" +export VAE_PATH="stabilityai/sdxl-vae" + +python ../../gaudi_spawn.py --world_size 8 --use_mpi train_dreambooth_lora_sdxl.py \ + --pretrained_model_name_or_path=$MODEL_NAME \ + --instance_data_dir=$INSTANCE_DIR \ + --pretrained_vae_model_name_or_path=$VAE_PATH \ + --output_dir=$OUTPUT_DIR \ + --mixed_precision="bf16" \ + --instance_prompt="a photo of sks dog" \ + --resolution=1024 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --learning_rate=1e-4 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --max_train_steps=500 \ + --validation_prompt="A photo of sks dog in a bucket" \ + --validation_epochs=25 \ + --seed=0 \ + --use_hpu_graphs_for_inference \ + --use_hpu_graphs_for_training \ + --gaudi_config_name Habana/stable-diffusion + +``` + +You could use text_to_image_generation.py to generate picture using the peft adapter like + +```bash +python ../text_to_image_generation.py \ + --model_name_or_path stabilityai/stable-diffusion-xl-base-1.0 \ + --prompts "A picture of a sks dog in a bucket" \ + --num_images_per_prompt 5 \ + --batch_size 1 \ + --image_save_dir /tmp/stable_diffusion_xl_images \ + --use_habana \ + --use_hpu_graphs \ + --lora_id lora-trained-xl \ + --gaudi_config Habana/stable-diffusion \ + --bf16 +``` \ No newline at end of file diff --git a/examples/stable-diffusion/training/requirements.txt b/examples/stable-diffusion/training/requirements.txt index acdd70a4a0..7fb1748675 100644 --- a/examples/stable-diffusion/training/requirements.txt +++ b/examples/stable-diffusion/training/requirements.txt @@ -1 +1,2 @@ imagesize +peft == 0.10.0 diff --git a/examples/stable-diffusion/training/train_dreambooth.py b/examples/stable-diffusion/training/train_dreambooth.py new file mode 100644 index 0000000000..b34f3c12c5 --- /dev/null +++ b/examples/stable-diffusion/training/train_dreambooth.py @@ -0,0 +1,1357 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +""" +Training script for DreamBooth to Text-to-Image Diffusion Models +Adapted from the following source: +https://github.com/huggingface/peft/blob/608a90ded9985ee1c5912d738082bb1fd618902b/examples/stable_diffusion/train_dreambooth.py +""" + +import argparse +import gc +import hashlib +import itertools +import logging +import math +import os +import threading +import warnings +from pathlib import Path +from typing import Union + +import datasets +import diffusers +import habana_frameworks.torch.core as htcore +import numpy as np +import psutil +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.optimization import get_scheduler +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from habana_frameworks.torch.hpu import memory_stats +from huggingface_hub import HfApi +from peft import LoHaConfig, LoKrConfig, LoraConfig, OFTConfig, get_peft_model +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType +from optimum.habana.diffusers import GaudiStableDiffusionPipeline +from optimum.habana.transformers.trainer import _is_peft_model +from optimum.habana.utils import set_seed + + +logger = get_logger(__name__) + +UNET_TARGET_MODULES = [ + "to_q", + "to_k", + "to_v", + "proj", + "proj_in", + "proj_out", + "conv", + "conv1", + "conv2", + "conv_shortcut", + "to_out.0", + "time_emb_proj", + "ff.net.2", +] + +TEXT_ENCODER_TARGET_MODULES = ["fc1", "fc2", "q_proj", "k_proj", "v_proj", "out_proj"] + + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + + +def create_unet_adapter_config(args: argparse.Namespace) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: + if args.adapter == "full": + raise ValueError("Cannot create unet adapter config for full parameter") + + if args.adapter == "lora": + config = LoraConfig( + r=args.unet_r, + lora_alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + lora_dropout=args.unet_dropout, + bias=args.unet_bias, + init_lora_weights=True, + ) + elif args.adapter == "loha": + config = LoHaConfig( + r=args.unet_r, + alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + rank_dropout=args.unet_rank_dropout, + module_dropout=args.unet_module_dropout, + use_effective_conv2d=args.unet_use_effective_conv2d, + init_weights=True, + ) + elif args.adapter == "lokr": + config = LoKrConfig( + r=args.unet_r, + alpha=args.unet_alpha, + target_modules=UNET_TARGET_MODULES, + rank_dropout=args.unet_rank_dropout, + module_dropout=args.unet_module_dropout, + use_effective_conv2d=args.unet_use_effective_conv2d, + decompose_both=args.unet_decompose_both, + decompose_factor=args.unet_decompose_factor, + init_weights=True, + ) + elif args.adapter == "oft": + config = OFTConfig( + r=args.unet_r, + target_modules=UNET_TARGET_MODULES, + module_dropout=args.unet_dropout, + init_weights=True, + coft=args.unet_use_coft, + eps=args.unet_eps, + ) + else: + raise ValueError(f"Unknown adapter type {args.adapter}") + + return config + + +def create_text_encoder_adapter_config( + args: argparse.Namespace, +) -> Union[LoraConfig, LoHaConfig, LoKrConfig, OFTConfig]: + if args.adapter == "full": + raise ValueError("Cannot create text_encoder adapter config for full parameter") + + if args.adapter == "lora": + config = LoraConfig( + r=args.te_r, + lora_alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + lora_dropout=args.te_dropout, + bias=args.te_bias, + init_lora_weights=True, + ) + elif args.adapter == "loha": + config = LoHaConfig( + r=args.te_r, + alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + rank_dropout=args.te_rank_dropout, + module_dropout=args.te_module_dropout, + init_weights=True, + ) + elif args.adapter == "lokr": + config = LoKrConfig( + r=args.te_r, + alpha=args.te_alpha, + target_modules=TEXT_ENCODER_TARGET_MODULES, + rank_dropout=args.te_rank_dropout, + module_dropout=args.te_module_dropout, + decompose_both=args.te_decompose_both, + decompose_factor=args.te_decompose_factor, + init_weights=True, + ) + elif args.adapter == "oft": + config = OFTConfig( + r=args.te_r, + target_modules=TEXT_ENCODER_TARGET_MODULES, + module_dropout=args.te_dropout, + init_weights=True, + coft=args.te_use_coft, + eps=args.te_eps, + ) + else: + raise ValueError(f"Unknown adapter type {args.adapter}") + + return config + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run dreambooth validation every X steps. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder") + + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--wandb_key", + type=str, + default=None, + help=("If report to option is set to wandb, api-key for wandb used for login to wandb "), + ) + parser.add_argument( + "--wandb_project_name", + type=str, + default=None, + help=("If report to option is set to wandb, project name in wandb for log tracking "), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "bf16"], + help=( + "Whether to use mixed precision. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "bf16"], + help=("Choose prior generation precision between fp32 and bf16 (bfloat16)."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--throughput_warmup_steps", + type=int, + default=0, + help=( + "Number of steps to ignore for throughput calculation. For example, with throughput_warmup_steps=N, the" + " first N steps will not be considered in the calculation of the throughput. This is especially useful in" + " lazy mode." + ), + ) + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.", + ) + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.", + ) + + # Adapter arguments + subparsers = parser.add_subparsers(dest="adapter") + + # Dummy subparser to train whole model + subparsers.add_parser("full", help="Train full model without adapters") + + # LoRA adapter + lora = subparsers.add_parser("lora", help="Use LoRA adapter") + lora.add_argument("--unet_r", type=int, default=8, help="LoRA rank for unet") + lora.add_argument("--unet_alpha", type=int, default=8, help="LoRA alpha for unet") + lora.add_argument("--unet_dropout", type=float, default=0.0, help="LoRA dropout probability for unet") + lora.add_argument( + "--unet_bias", + type=str, + default="none", + help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only'", + ) + lora.add_argument( + "--te_r", type=int, default=8, help="LoRA rank for text_encoder, only used if `train_text_encoder` is True" + ) + lora.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoRA alpha for text_encoder, only used if `train_text_encoder` is True", + ) + lora.add_argument( + "--te_dropout", + type=float, + default=0.0, + help="LoRA dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lora.add_argument( + "--te_bias", + type=str, + default="none", + help="Bias type for LoRA. Can be 'none', 'all' or 'lora_only', only used if `train_text_encoder` is True", + ) + + # LoHa adapter + loha = subparsers.add_parser("loha", help="Use LoHa adapter") + loha.add_argument("--unet_r", type=int, default=8, help="LoHa rank for unet") + loha.add_argument("--unet_alpha", type=int, default=8, help="LoHa alpha for unet") + loha.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoHa rank_dropout probability for unet") + loha.add_argument( + "--unet_module_dropout", type=float, default=0.0, help="LoHa module_dropout probability for unet" + ) + loha.add_argument( + "--unet_use_effective_conv2d", + action="store_true", + help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1", + ) + loha.add_argument( + "--te_r", type=int, default=8, help="LoHa rank for text_encoder, only used if `train_text_encoder` is True" + ) + loha.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoHa alpha for text_encoder, only used if `train_text_encoder` is True", + ) + loha.add_argument( + "--te_rank_dropout", + type=float, + default=0.0, + help="LoHa rank_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + loha.add_argument( + "--te_module_dropout", + type=float, + default=0.0, + help="LoHa module_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + + # LoKr adapter + lokr = subparsers.add_parser("lokr", help="Use LoKr adapter") + lokr.add_argument("--unet_r", type=int, default=8, help="LoKr rank for unet") + lokr.add_argument("--unet_alpha", type=int, default=8, help="LoKr alpha for unet") + lokr.add_argument("--unet_rank_dropout", type=float, default=0.0, help="LoKr rank_dropout probability for unet") + lokr.add_argument( + "--unet_module_dropout", type=float, default=0.0, help="LoKr module_dropout probability for unet" + ) + lokr.add_argument( + "--unet_use_effective_conv2d", + action="store_true", + help="Use parameter effective decomposition in unet for Conv2d 3x3 with ksize > 1", + ) + lokr.add_argument( + "--unet_decompose_both", action="store_true", help="Decompose left matrix in kronecker product for unet" + ) + lokr.add_argument( + "--unet_decompose_factor", type=int, default=-1, help="Decompose factor in kronecker product for unet" + ) + lokr.add_argument( + "--te_r", type=int, default=8, help="LoKr rank for text_encoder, only used if `train_text_encoder` is True" + ) + lokr.add_argument( + "--te_alpha", + type=int, + default=8, + help="LoKr alpha for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_rank_dropout", + type=float, + default=0.0, + help="LoKr rank_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_module_dropout", + type=float, + default=0.0, + help="LoKr module_dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_decompose_both", + action="store_true", + help="Decompose left matrix in kronecker product for text_encoder, only used if `train_text_encoder` is True", + ) + lokr.add_argument( + "--te_decompose_factor", + type=int, + default=-1, + help="Decompose factor in kronecker product for text_encoder, only used if `train_text_encoder` is True", + ) + # oft adapter + oft = subparsers.add_parser("oft", help="Use Oft adapter") + oft.add_argument("--unet_r", type=int, default=8, help="Oft rank for unet") + oft.add_argument("--unet_dropout", type=float, default=0.0, help="Oft dropout probability for unet") + oft.add_argument("--unet_use_coft", action="store_true", help="Using constrained OFT in unet") + oft.add_argument("--unet_eps", type=float, default=0.0, help="The control strength of COFT for unet") + oft.add_argument( + "--te_r", type=int, default=8, help="Oft rank for text_encoder, only used if `train_text_encoder` is True" + ) + oft.add_argument( + "--te_dropout", + type=float, + default=0.0, + help="Oft dropout probability for text_encoder, only used if `train_text_encoder` is True", + ) + oft.add_argument( + "--te_use_coft", + action="store_true", + help="Using constrained OFT in text_encoder, only used if `train_text_encoder` is True", + ) + oft.add_argument( + "--te_eps", + type=float, + default=0.0, + help="The control strength of COFT for text_encoder, only used if `train_text_encoder` is True", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +# Converting Bytes to Megabytes +def b2mb(x): + return int(x / 2**20) + + +# This context manager is used to track the peak memory usage of the process +class TorchTracemalloc: + def __enter__(self): + gc.collect() + mem_stats = memory_stats() + + self.begin = mem_stats["InUse"] + self.process = psutil.Process() + + self.cpu_begin = self.cpu_mem_used() + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + mem_stats = memory_stats() + + self.end = mem_stats["InUse"] + self.peak = mem_stats["MaxInUse"] + self.used = b2mb(self.end - self.begin) + self.peaked = b2mb(self.peak - self.begin) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = b2mb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin) + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images and the tokenizes prompts. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + truncation=True, + padding="max_length", + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + return example + + +def collate_fn(examples, with_prior_preservation=False): + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + input_ids = torch.cat(input_ids, dim=0) + + batch = { + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16" + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + force_autocast=gaudi_config.use_torch_autocast, + ) + if args.report_to == "wandb": + import wandb + + wandb.login(key=args.wandb_key) + wandb.init(project=args.wandb_project_name) + # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate + # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models. + # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate. + if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: + raise ValueError( + "Gradient accumulation is not supported when training the text encoder in distributed training. " + "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.bfloat16 if accelerator.device.type == "hpu" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + safety_checker=None, + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + api = HfApi(token=args.hub_token) + # Create repo (repo_name from args or inferred) + repo_name = args.hub_model_id + if repo_name is None: + repo_name = Path(args.output_dir).absolute().name + repo_id = api.create_repo(repo_name, exist_ok=True).repo_id + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer + if args.tokenizer_name: + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False) + elif args.pretrained_model_name_or_path: + tokenizer = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder class + text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision) + + # Load scheduler and models + noise_scheduler = DDPMScheduler( + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000, + ) # DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder = text_encoder_cls.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision + ) + if args.adapter != "full": + config = create_unet_adapter_config(args) + unet = get_peft_model(unet, config) + unet.print_trainable_parameters() + unet.to(accelerator.device) + vae.requires_grad_(False) + if not args.train_text_encoder: + text_encoder.requires_grad_(False) + elif args.train_text_encoder and args.adapter != "full": + config = create_text_encoder_adapter_config(args) + text_encoder = get_peft_model(text_encoder, config) + text_encoder.print_trainable_parameters() + text_encoder.to(accelerator.device) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder and not args.adapter != "full": + text_encoder.gradient_checkpointing_enable() + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + if gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_class = FusedAdamW + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters() + ) + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=1, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=weight_dtype) + if not args.train_text_encoder: + text_encoder.to(accelerator.device, dtype=weight_dtype) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + def unwrap_model(model, training=False): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + if not training: + return model + else: + if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU: + kwargs = {} + kwargs["gradient_as_bucket_view"] = True + accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + if args.use_hpu_graphs_for_training: + if _is_peft_model(model): + base_model = model.get_base_model() + htcore.hpu.ModuleCacher()(model=base_model, inplace=True) + else: + htcore.hpu.ModuleCacher()(model=model, inplace=True) + return model + + unwrap_model(model=unet, training=True) + if args.train_text_encoder: + unwrap_model(model=text_encoder, training=True) + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + resume_global_step = global_step * args.gradient_accumulation_steps + first_epoch = resume_global_step // num_update_steps_per_epoch + resume_step = resume_global_step % num_update_steps_per_epoch + + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder.train() + with TorchTracemalloc() as tracemalloc: + for step, batch in enumerate(train_dataloader): + # Skip steps until we reach the resumed step + if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step: + if step % args.gradient_accumulation_steps == 0: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + continue + + with accelerator.accumulate(unet): + # Convert images to latent space + latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(batch["input_ids"])[0] + + # Predict the noise residual + model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + htcore.mark_step() + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet.parameters(), text_encoder.parameters()) + if args.train_text_encoder + else unet.parameters() + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + if args.report_to == "wandb": + accelerator.print(progress_bar) + global_step += 1 + + # if global_step % args.checkpointing_steps == 0: + # if accelerator.is_main_process: + # save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + # accelerator.save_state(save_path) + # logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if ( + args.validation_prompt is not None + and (step + num_update_steps_per_epoch * epoch) % args.validation_steps == 0 + ): + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + safety_checker=None, + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + # set `keep_fp32_wrapper` to True because we do not want to remove + # mixed precision hooks while we are still training + pipeline.unet = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) + pipeline.text_encoder = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # Set evaliation mode + pipeline.unet.eval() + pipeline.text_encoder.eval() + + # run inference + if args.seed is not None: + if accelerator.device == torch.device("hpu"): + # torch.Generator() is unsupported on HPU + generator = set_seed(args.seed) + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + images = [] + for _ in range(args.num_validation_images): + image = pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + images.append(image) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + import wandb + + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + # Set evaliation mode + pipeline.unet.train() + if args.train_text_encoder: + pipeline.text_encoder.train() + + del pipeline + + if global_step >= args.max_train_steps: + break + # Printing the HPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print(f"HPU Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"HPU Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"HPU Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"HPU Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + + accelerator.print(f"CPU Memory before entering the train : {b2mb(tracemalloc.cpu_begin)}") + accelerator.print(f"CPU Memory consumed at the end of the train (end-begin): {tracemalloc.cpu_used}") + accelerator.print(f"CPU Peak Memory consumed during the train (max-begin): {tracemalloc.cpu_peaked}") + accelerator.print( + f"CPU Total Peak Memory consumed during the train (max): {tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}" + ) + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + if args.adapter != "full": + unwarpped_unet = unwrap_model(unet) + unwarpped_unet.save_pretrained( + os.path.join(args.output_dir, "unet"), state_dict=accelerator.get_state_dict(unet) + ) + if args.train_text_encoder: + unwarpped_text_encoder = unwrap_model(text_encoder) + unwarpped_text_encoder.save_pretrained( + os.path.join(args.output_dir, "text_encoder"), + state_dict=accelerator.get_state_dict(text_encoder), + ) + else: + pipeline = GaudiStableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=unwrap_model(unet), + text_encoder=unwrap_model(text_encoder), + revision=args.revision, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + api.upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + run_as_future=True, + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py new file mode 100644 index 0000000000..ea34c50773 --- /dev/null +++ b/examples/stable-diffusion/training/train_dreambooth_lora_sdxl.py @@ -0,0 +1,1768 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +""" +Training script for LORA DreamBooth to Text-to-Image Diffusion Models +Adapted from the following source: +https://github.com/huggingface/diffusers/blob/v0.26.3/examples/dreambooth/train_dreambooth_lora_sdxl.py +""" + +import argparse +import gc +import itertools +import logging +import math +import os +import shutil +import warnings +from pathlib import Path + +import diffusers +import habana_frameworks.torch.core as htcore +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate.logging import get_logger +from accelerate.utils import DistributedDataParallelKwargs +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + DPMSolverMultistepScheduler, + UNet2DConditionModel, +) +from diffusers.loaders import LoraLoaderMixin +from diffusers.optimization import get_scheduler +from diffusers.training_utils import _set_state_dict_into_text_encoder, compute_snr +from diffusers.utils import ( + check_min_version, + convert_state_dict_to_diffusers, + convert_unet_state_dict_to_peft, + is_wandb_available, +) +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.torch_utils import is_compiled_module +from huggingface_hub import create_repo, upload_folder +from huggingface_hub.utils import insecure_hashlib +from packaging import version +from peft import LoraConfig, set_peft_model_state_dict +from peft.utils import get_peft_model_state_dict +from PIL import Image +from PIL.ImageOps import exif_transpose +from torch.utils.data import Dataset +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +from optimum.habana import GaudiConfig +from optimum.habana.accelerate import GaudiAccelerator +from optimum.habana.accelerate.utils.dataclasses import GaudiDistributedType +from optimum.habana.diffusers import GaudiStableDiffusionXLPipeline +from optimum.habana.transformers.trainer import _is_peft_model +from optimum.habana.utils import set_seed + + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.26.0") + +logger = get_logger(__name__) + + +def save_model_card( + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, +): + img_str = "widget:\n" if images else "" + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + output: + url: + "image_{i}.png" + """ + + yaml = f""" +--- +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- lora +- template:sd-lora +{img_str} +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ +--- + """ + + model_card = f""" +# SDXL LoRA DreamBooth - {repo_id} + + + +## Model description + +These are {repo_id} LoRA adaption weights for {base_model}. + +The weights were trained using [DreamBooth](https://dreambooth.github.io/). + +LoRA for the text encoder was enabled: {train_text_encoder}. + +Special VAE used for training: {vae_path}. + +## Trigger words + +You should use {instance_prompt} to trigger the image generation. + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + help=("A folder containing the training data. "), + ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") + + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=False, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + required=True, + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided instance images.", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + help="A prompt that is used during validation to verify that the model is learning.", + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images that should be generated during validation with `validation_prompt`.", + ) + parser.add_argument( + "--validation_epochs", + type=int, + default=50, + help=( + "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`." + ), + ) + parser.add_argument( + "--with_prior_preservation", + default=False, + action="store_true", + help="Flag to add prior preservation loss.", + ) + parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.") + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="lora-dreambooth-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=1024, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--train_text_encoder", + action="store_true", + help="Whether to train the text encoder. If set, the text encoder should be float32 precision.", + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final" + " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming" + " training using `--resume_from_checkpoint`." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + + parser.add_argument( + "--optimizer", + type=str, + default="AdamW", + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), + ) + + parser.add_argument( + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2. Ignored if optimizer is adamW", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" + ) + + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) + + parser.add_argument( + "--prodigy_use_bias_correction", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", + ) + parser.add_argument( + "--prodigy_safeguard_warmup", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", + ) + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "bf16"], + help=( + "Whether to use mixed precision. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--prior_generation_precision", + type=str, + default=None, + choices=["no", "fp32", "bf16"], + help=("Choose prior generation precision between fp32 and bf16 (bfloat16)."), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--rank", + type=int, + default=4, + help=("The dimension of the LoRA update matrices."), + ) + parser.add_argument( + "--gaudi_config_name", + type=str, + default=None, + help="Local path to the Gaudi configuration file or its name on the Hugging Face Hub.", + ) + parser.add_argument( + "--use_hpu_graphs_for_training", + action="store_true", + help="Use HPU graphs for training on HPU.", + ) + parser.add_argument( + "--use_hpu_graphs_for_inference", + action="store_true", + help="Use HPU graphs for inference on HPU.", + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.instance_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.instance_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.with_prior_preservation: + if args.class_data_dir is None: + raise ValueError("You must specify a data directory for class images.") + if args.class_prompt is None: + raise ValueError("You must specify prompt for class images.") + else: + # logger is not available yet + if args.class_data_dir is not None: + warnings.warn("You need not use --class_data_dir without --with_prior_preservation.") + if args.class_prompt is not None: + warnings.warn("You need not use --class_prompt without --with_prior_preservation.") + + return args + + +class DreamBoothDataset(Dataset): + """ + A dataset to prepare the instance and class images with the prompts for fine-tuning the model. + It pre-processes the images. + """ + + def __init__( + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset + if args.dataset_name is not None: + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) + else: + self.instance_data_root = Path(instance_data_root) + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None + + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) + self.num_instance_images = len(self.instance_images) + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + self.class_data_root.mkdir(parents=True, exist_ok=True) + self.class_images_path = list(self.class_data_root.iterdir()) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + else: + self.class_data_root = None + + self.image_transforms = transforms.Compose( + [ + transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def __len__(self): + return self._length + + def __getitem__(self, index): + example = {} + instance_image = self.instance_images[index % self.num_instance_images] + instance_image = exif_transpose(instance_image) + + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.image_transforms(instance_image) + + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = caption + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + class_image = exif_transpose(class_image) + + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt + + return example + + +def collate_fn(examples, with_prior_preservation=False): + pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] + + # Concat class and instance examples for prior preservation. + # We do this to avoid doing two forward passes. + if with_prior_preservation: + pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + batch = {"pixel_values": pixel_values, "prompts": prompts} + return batch + + +class PromptDataset(Dataset): + "A simple dataset to prepare the prompts to generate class images on multiple GPUs." + + def __init__(self, prompt, num_samples): + self.prompt = prompt + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + example = {} + example["prompt"] = self.prompt + example["index"] = index + return example + + +def tokenize_prompt(tokenizer, prompt): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + return text_input_ids + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None): + prompt_embeds_list = [] + + for i, text_encoder in enumerate(text_encoders): + if tokenizers is not None: + tokenizer = tokenizers[i] + text_input_ids = tokenize_prompt(tokenizer, prompt) + else: + assert text_input_ids_list is not None + text_input_ids = text_input_ids_list[i] + + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), output_hidden_states=True, return_dict=False + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds[-1][-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + logging_dir = Path(args.output_dir, args.logging_dir) + gaudi_config = GaudiConfig.from_pretrained(args.gaudi_config_name) + gaudi_config.use_torch_autocast = gaudi_config.use_torch_autocast or args.mixed_precision == "bf16" + accelerator = GaudiAccelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_dir=logging_dir, + force_autocast=gaudi_config.use_torch_autocast, + ) + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + import wandb + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Generate class images if prior preservation is enabled. + if args.with_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + torch_dtype = torch.bfloat16 if accelerator.device.type == "hpu" else torch.float32 + if args.prior_generation_precision == "fp32": + torch_dtype = torch.float32 + elif args.prior_generation_precision == "bf16": + torch_dtype = torch.bfloat16 + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + torch_dtype=torch_dtype, + revision=args.revision, + variant=args.variant, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.set_progress_bar_config(disable=True) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + + sample_dataset = PromptDataset(args.class_prompt, num_new_images) + sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size) + + sample_dataloader = accelerator.prepare(sample_dataloader) + pipeline.to(accelerator.device) + + for example in tqdm( + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + images = pipeline(example["prompt"]).images + + for i, image in enumerate(images): + hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest() + image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg" + image.save(image_filename) + + del pipeline + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + use_fast=False, + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + use_fast=False, + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant + ) + + # We only train the additional adapter LoRA layers + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + unet.requires_grad_(False) + + # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + unet.to(accelerator.device, dtype=weight_dtype) + + # The VAE is always in float32 to avoid NaN losses. + vae.to(accelerator.device, dtype=torch.float32) + + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + if args.train_text_encoder: + text_encoder_one.gradient_checkpointing_enable() + text_encoder_two.gradient_checkpointing_enable() + + # now we will add new LoRA weights to the attention layers + unet_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["to_k", "to_q", "to_v", "to_out.0"], + ) + unet.add_adapter(unet_lora_config) + + # The text encoder comes from 🤗 transformers, so we cannot directly modify it. + # So, instead, we monkey-patch the forward calls of its attention-blocks. + if args.train_text_encoder: + text_lora_config = LoraConfig( + r=args.rank, + lora_alpha=args.rank, + init_lora_weights="gaussian", + target_modules=["q_proj", "k_proj", "v_proj", "out_proj"], + ) + text_encoder_one.add_adapter(text_lora_config) + text_encoder_two.add_adapter(text_lora_config) + + def unwrap_model(model, training=False): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + if not training: + return model + else: + if accelerator.distributed_type == GaudiDistributedType.MULTI_HPU: + kwargs = {} + kwargs["gradient_as_bucket_view"] = True + accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) + if args.use_hpu_graphs_for_training: + if _is_peft_model(model): + base_model = model.get_base_model() + htcore.hpu.ModuleCacher()(model=base_model, inplace=True) + else: + htcore.hpu.ModuleCacher()(model=model, inplace=True) + return model + + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_one_lora_layers_to_save = None + text_encoder_two_lora_layers_to_save = None + + for model in models: + if isinstance(model, type(unwrap_model(unet))): + unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model)) + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers( + get_peft_model_state_dict(model) + ) + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + GaudiStableDiffusionXLPipeline.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_one_lora_layers_to_save, + text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + unet_ = None + text_encoder_one_ = None + text_encoder_two_ = None + + while len(models) > 0: + model = models.pop() + + if isinstance(model, type(unwrap_model(unet))): + unet_ = model + elif isinstance(model, type(unwrap_model(text_encoder_one))): + text_encoder_one_ = model + elif isinstance(model, type(unwrap_model(text_encoder_two))): + text_encoder_two_ = model + else: + raise ValueError(f"unexpected save model: {model.__class__}") + + lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir) + + unet_state_dict = {f'{k.replace("unet.", "")}': v for k, v in lora_state_dict.items() if k.startswith("unet.")} + unet_state_dict = convert_unet_state_dict_to_peft(unet_state_dict) + incompatible_keys = set_peft_model_state_dict(unet_, unet_state_dict, adapter_name="default") + if incompatible_keys is not None: + # check only for unexpected keys + unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None) + if unexpected_keys: + logger.warning( + f"Loading adapter weights from state_dict led to unexpected keys not found in the model: " + f" {unexpected_keys}. " + ) + + if args.train_text_encoder: + # Do we need to call `scale_lora_layers()` here? + _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_one_) + + _set_state_dict_into_text_encoder( + lora_state_dict, prefix="text_encoder_2.", text_encoder=text_encoder_two_ + ) + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + unet_lora_parameters = list(filter(lambda p: p.requires_grad, unet.parameters())) + + if args.train_text_encoder: + text_lora_parameters_one = list(filter(lambda p: p.requires_grad, text_encoder_one.parameters())) + text_lora_parameters_two = list(filter(lambda p: p.requires_grad, text_encoder_two.parameters())) + + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if args.train_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] + else: + params_to_optimize = [unet_lora_parameters_with_lr] + + # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) + args.optimizer = "adamw" + + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + elif gaudi_config.use_fused_adam: + from habana_frameworks.torch.hpex.optimizers import FusedAdamW + + optimizer_class = FusedAdamW + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + if args.optimizer.lower() == "prodigy": + try: + import prodigyopt + except ImportError: + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warn( + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.train_text_encoder and args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate + + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + repeats=args.repeats, + center_crop=args.center_crop, + ) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, + ) + + # Computes additional embeddings/ids required by the SDXL UNet. + # regular text embeddings (when `train_text_encoder` is not True) + # pooled text embeddings + # time ids + + def compute_time_ids(): + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype) + return add_time_ids + + if not args.train_text_encoder: + tokenizers = [tokenizer_one, tokenizer_two] + text_encoders = [text_encoder_one, text_encoder_two] + + def compute_text_embeddings(prompt, text_encoders, tokenizers): + with torch.no_grad(): + prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return prompt_embeds, pooled_prompt_embeds + + # Handle instance prompt. + instance_time_ids = compute_time_ids() + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( + args.instance_prompt, text_encoders, tokenizers + ) + + # Handle class prompt for prior-preservation. + if args.with_prior_preservation: + class_time_ids = compute_time_ids() + if not args.train_text_encoder: + class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings( + args.class_prompt, text_encoders, tokenizers + ) + + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: + del tokenizers, text_encoders + gc.collect() + + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't + # have to pass them to the dataloader. + add_time_ids = instance_time_ids + if args.with_prior_preservation: + add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) + + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + if args.train_text_encoder: + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler + ) + else: + unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + unet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth-lora-sd-xl", config=vars(args)) + + unwrap_model(model=unet, training=True) + if args.train_text_encoder: + unwrap_model(model=text_encoder_one, training=True) + unwrap_model(model=text_encoder_two, training=True) + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the mos recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + for epoch in range(first_epoch, args.num_train_epochs): + unet.train() + if args.train_text_encoder: + text_encoder_one.train() + text_encoder_two.train() + + # set top parameter requires_grad = True for gradient checkpointing works + accelerator.unwrap_model(text_encoder_one).text_model.embeddings.requires_grad_(True) + accelerator.unwrap_model(text_encoder_two).text_model.embeddings.requires_grad_(True) + + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(unet): + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds, unet_add_text_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers + ) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts) + tokens_two = tokenize_prompt(tokenizer_two, prompts) + + # Convert images to latent space + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = model_input * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + model_input = model_input.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device + ) + timesteps = timesteps.long() + + # Add noise to the model input according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + else: + elems_to_repeat_text_embeds = 1 + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz + + # Predict the noise residual + if not args.train_text_encoder: + unet_added_conditions = { + "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), + } + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] + else: + unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders=[text_encoder_one, text_encoder_two], + tokenizers=None, + prompt=None, + text_input_ids_list=[tokens_one, tokens_two], + ) + unet_added_conditions.update( + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + ) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) + model_pred = unet( + noisy_model_input, + timesteps, + prompt_embeds_input, + added_cond_kwargs=unet_added_conditions, + return_dict=False, + )[0] + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(model_input, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + + if args.with_prior_preservation: + # Chunk the noise and model_pred into two parts and compute the loss on each part separately. + model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) + target, target_prior = torch.chunk(target, 2, dim=0) + + # Compute prior loss + prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(noise_scheduler, timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + + accelerator.backward(loss) + htcore.mark_step() + if accelerator.sync_gradients: + params_to_clip = ( + itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) + if args.train_text_encoder + else unet_lora_parameters + ) + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=True) + htcore.mark_step() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + if accelerator.is_main_process: + if args.validation_prompt is not None and epoch % args.validation_epochs == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + text_encoder=accelerator.unwrap_model(text_encoder_one), + text_encoder_2=accelerator.unwrap_model(text_encoder_two), + unet=accelerator.unwrap_model(unet), + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + pipeline.text_encoder.eval() + pipeline.text_encoder_2.eval() + pipeline.unet.eval() + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + pipeline.scheduler.config, **scheduler_args + ) + + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + # run inference + if args.seed is not None: + if accelerator.device == torch.device("hpu"): + # torch.Generator() is unsupported on HPU + generator = set_seed(args.seed) + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + else: + generator = None + pipeline_args = {"prompt": args.validation_prompt} + + images = [ + pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + pipeline.unet.train() + if args.train_text_encoder: + pipeline.text_encoder.train() + pipeline.text_encoder_2.train() + del pipeline + + # Save the lora layers + accelerator.wait_for_everyone() + if accelerator.is_main_process: + unet = unwrap_model(unet) + unet = unet.to(torch.float32) + unet_lora_layers = convert_state_dict_to_diffusers(get_peft_model_state_dict(unet)) + + if args.train_text_encoder: + text_encoder_one = unwrap_model(text_encoder_one) + text_encoder_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_one.to(torch.float32)) + ) + text_encoder_two = unwrap_model(text_encoder_two) + text_encoder_2_lora_layers = convert_state_dict_to_diffusers( + get_peft_model_state_dict(text_encoder_two.to(torch.float32)) + ) + else: + text_encoder_lora_layers = None + text_encoder_2_lora_layers = None + + GaudiStableDiffusionXLPipeline.save_lora_weights( + save_directory=args.output_dir, + unet_lora_layers=unet_lora_layers, + text_encoder_lora_layers=text_encoder_lora_layers, + text_encoder_2_lora_layers=text_encoder_2_lora_layers, + ) + # Final inference + # Load previous pipeline + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + use_habana=True, + gaudi_config=gaudi_config, + ) + + # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it + scheduler_args = {} + + if "variance_type" in pipeline.scheduler.config: + variance_type = pipeline.scheduler.config.variance_type + + if variance_type in ["learned", "learned_range"]: + variance_type = "fixed_small" + + scheduler_args["variance_type"] = variance_type + + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) + + # load attention processors + pipeline.load_lora_weights(args.output_dir) + + # run inference + images = [] + if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + if args.push_to_hub: + save_model_card( + repo_id, + images=images, + base_model=args.pretrained_model_name_or_path, + train_text_encoder=args.train_text_encoder, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, + repo_folder=args.output_dir, + vae_path=args.pretrained_vae_model_name_or_path, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/optimum/habana/diffusers/pipelines/pipeline_utils.py b/optimum/habana/diffusers/pipelines/pipeline_utils.py index 14e074eef3..c2ffdb9cc5 100644 --- a/optimum/habana/diffusers/pipelines/pipeline_utils.py +++ b/optimum/habana/diffusers/pipelines/pipeline_utils.py @@ -19,7 +19,7 @@ import inspect import os import sys -from typing import Optional, Union +from typing import Callable, Dict, Optional, Union import torch from diffusers.pipelines import DiffusionPipeline @@ -28,6 +28,7 @@ from diffusers.utils.torch_utils import is_compiled_module from huggingface_hub import create_repo +from optimum.habana.utils import to_device_dtype from optimum.utils import logging from ...transformers.gaudi_configuration import GaudiConfig @@ -359,3 +360,33 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P pretrained_model_name_or_path, **kwargs, ) + + @classmethod + def save_lora_weights( + cls, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + # Move the state dict from HPU to CPU before saving + if unet_lora_layers: + unet_lora_layers = to_device_dtype(unet_lora_layers, target_device=torch.device("cpu")) + if text_encoder_lora_layers: + text_encoder_lora_layers = to_device_dtype(text_encoder_lora_layers, target_device=torch.device("cpu")) + if text_encoder_2_lora_layers: + text_encoder_2_lora_layers = to_device_dtype(text_encoder_2_lora_layers, target_device=torch.device("cpu")) + return super().save_lora_weights( + save_directory, + unet_lora_layers, + text_encoder_lora_layers, + text_encoder_2_lora_layers, + is_main_process, + weight_name, + save_function, + safe_serialization, + ) diff --git a/tests/resource/image-captioning-example.png b/tests/resource/img/image-captioning-example.png similarity index 100% rename from tests/resource/image-captioning-example.png rename to tests/resource/img/image-captioning-example.png diff --git a/tests/test_diffusers.py b/tests/test_diffusers.py index a1c31f2181..e9628620e7 100755 --- a/tests/test_diffusers.py +++ b/tests/test_diffusers.py @@ -22,11 +22,13 @@ import tempfile from io import BytesIO from pathlib import Path +from typing import Union from unittest import TestCase, skipUnless import numpy as np import pytest import requests +import safetensors import torch from diffusers import ( AutoencoderKL, @@ -1896,6 +1898,171 @@ def test_train_controlnet(self): self.assertEqual(image.shape, (512, 512, 3)) +def install_requirements(requirements_filename: Union[str, os.PathLike]): + """ + Installs the necessary requirements to run the example if the provided file exists, otherwise does nothing. + """ + + if not Path(requirements_filename).exists(): + return + + cmd_line = f"pip install -r {requirements_filename}".split() + p = subprocess.Popen(cmd_line) + return_code = p.wait() + assert return_code == 0 + + +class DreamBooth(TestCase): + def _test_dreambooth(self, extra_config, train_text_encoder=False): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "stable-diffusion" + / "training" + / "train_dreambooth.py" + ) + install_requirements(path_to_script.parent / "requirements.txt") + instance_prompt = "soccer player kicking a ball" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + python3 + {path_to_script} + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe + --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --train_text_encoder + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --gaudi_config_name Habana/stable-diffusion + --output_dir {tmpdir} + """.split() + + test_args.append("--instance_prompt") + test_args.append(instance_prompt) + if "oft" not in extra_config: + test_args.append("--use_hpu_graphs_for_training") + test_args.append("--use_hpu_graphs_for_inference") + if train_text_encoder: + test_args.append("--train_text_encoder") + test_args.append(extra_config) + p = subprocess.Popen(test_args) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + # save_pretrained smoke test + if "full" in extra_config: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "diffusion_pytorch_model.safetensors"))) + if train_text_encoder: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "text_encoder", "model.safetensors"))) + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json"))) + else: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "unet", "adapter_model.safetensors"))) + if train_text_encoder: + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "text_encoder", "adapter_model.safetensors"))) + + def test_dreambooth_full(self): + self._test_dreambooth("full") + + def test_dreambooth_full_with_text_encoder(self): + self._test_dreambooth("full", train_text_encoder=True) + + def test_dreambooth_lora(self): + self._test_dreambooth("lora") + + def test_dreambooth_lora_with_text_encoder(self): + self._test_dreambooth("lora", train_text_encoder=True) + + def test_dreambooth_lokr(self): + self._test_dreambooth("lokr") + + def test_dreambooth_lokr_with_text_encoder(self): + self._test_dreambooth("lokr", train_text_encoder=True) + + def test_dreambooth_loha(self): + self._test_dreambooth("loha") + + def test_dreambooth_loha_with_text_encoder(self): + self._test_dreambooth("loha", train_text_encoder=True) + + def test_dreambooth_oft(self): + self._test_dreambooth("oft") + + def test_dreambooth_oft_with_text_encoder(self): + self._test_dreambooth("oft", train_text_encoder=True) + + +class DreamBoothLoRASDXL(TestCase): + def _test_dreambooth_lora_sdxl(self, train_text_encoder=False): + path_to_script = ( + Path(os.path.dirname(__file__)).parent + / "examples" + / "stable-diffusion" + / "training" + / "train_dreambooth_lora_sdxl.py" + ) + install_requirements(path_to_script.parent / "requirements.txt") + + instance_prompt = "soccer player kicking a ball" + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + python3 + {path_to_script} + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --instance_data_dir {Path(os.path.dirname(__file__))/'resource/img'} + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 1 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --gaudi_config_name Habana/stable-diffusion + --use_hpu_graphs_for_training + --use_hpu_graphs_for_inference + --output_dir {tmpdir} + """.split() + if train_text_encoder: + test_args.append("--train_text_encoder") + test_args.append("--instance_prompt") + test_args.append(instance_prompt) + p = subprocess.Popen(test_args) + return_code = p.wait() + + # Ensure the run finished without any issue + self.assertEqual(return_code, 0) + # save_pretrained smoke test + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) + + # make sure the state_dict has the correct naming in the parameters. + lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) + is_lora = all("lora" in k for k in lora_state_dict.keys()) + self.assertTrue(is_lora) + + # when not training the text encoder, all the parameters in the state dict should start + # with `"unet"` in their names. + if train_text_encoder: + starts_with_unet = all( + k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") + for k in lora_state_dict.keys() + ) + else: + starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys()) + self.assertTrue(starts_with_unet) + + def test_dreambooth_lora_sdxl_with_text_encoder(self): + self._test_dreambooth_lora_sdxl(train_text_encoder=True) + + def test_dreambooth_lora_sdxl(self): + self._test_dreambooth_lora_sdxl(train_text_encoder=False) + + class GaudiStableVideoDiffusionPipelineTester(TestCase): """ Tests the StableVideoDiffusionPipeline for Gaudi. diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 5a56d0a12e..a73d088bdd 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os from unittest import TestCase import torch @@ -32,7 +33,7 @@ def _test_image_to_text(self, model, expected_result): "max_new_tokens": 128, "ignore_eos": False, } - image = "./tests/resource/image-captioning-example.png" + image = os.path.dirname(__file__) + "/resource/img/image-captioning-example.png" for model_dtype in MODEL_DTYPE_LIST: generator = pipeline( "image-to-text",