diff --git a/examples/test_examples.py b/examples/test_examples.py
index c12154a0d572..c06c9417d594 100644
--- a/examples/test_examples.py
+++ b/examples/test_examples.py
@@ -828,6 +828,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
             {"checkpoint-4", "checkpoint-6"},
         )
 
+    def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
+    def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
+        prompt = "a prompt"
+        pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Run training script with checkpointing
+            # max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
+            # Should create checkpoints at steps 2, 4, 6
+            # with checkpoint at step 2 deleted
+
+            initial_run_args = f"""
+                examples/text_to_image/train_text_to_image_lora_sdxl.py
+                --pretrained_model_name_or_path {pipeline_path}
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 7
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --train_text_encoder
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --checkpointing_steps=2
+                --checkpoints_total_limit=2
+                """.split()
+
+            run_command(self._launch_args + initial_run_args)
+
+            pipe = DiffusionPipeline.from_pretrained(pipeline_path)
+            pipe.load_lora_weights(tmpdir)
+            pipe(prompt, num_inference_steps=2)
+
+            # check checkpoint directories exist
+            self.assertEqual(
+                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
+                # checkpoint-2 should have been deleted
+                {"checkpoint-4", "checkpoint-6"},
+            )
+
     def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
         pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
         prompt = "a prompt"
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py
index d7c2d07be431..fe8bdc594b38 100644
--- a/examples/text_to_image/train_text_to_image_lora_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -396,16 +396,6 @@ def parse_args(input_args=None):
             " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
         ),
     )
-    parser.add_argument(
-        "--prior_generation_precision",
-        type=str,
-        default=None,
-        choices=["no", "fp32", "fp16", "bf16"],
-        help=(
-            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
-            " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
-        ),
-    )
     parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
     parser.add_argument(
         "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -724,11 +714,15 @@ def load_model_hook(models, input_dir):
 
         lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
         LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
+
+        text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
+            text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
         )
+
+        text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
+            text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
         )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
@@ -1002,9 +996,12 @@ def collate_fn(examples):
                    continue
 
                with accelerator.accumulate(unet):
-                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
-
                     # Convert images to latent space
+                    if args.pretrained_vae_model_name_or_path is not None:
+                        pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
+                    else:
+                        pixel_values = batch["pixel_values"]
+
                     model_input = vae.encode(pixel_values).latent_dist.sample()
                     model_input = model_input * vae.config.scaling_factor
                     if args.pretrained_vae_model_name_or_path is None:
@@ -1147,13 +1144,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
                         f" {args.validation_prompt}."
                     )
                     # create pipeline
-                    if not args.train_text_encoder:
-                        text_encoder_one = text_encoder_cls_one.from_pretrained(
-                            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-                        )
-                        text_encoder_two = text_encoder_cls_two.from_pretrained(
-                            args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
-                        )
                     pipeline = StableDiffusionXLPipeline.from_pretrained(
                         args.pretrained_model_name_or_path,
                         vae=vae,
diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py
index 7f06da81ba38..c2fe98993d00 100644
--- a/tests/models/test_lora_layers.py
+++ b/tests/models/test_lora_layers.py
@@ -664,6 +664,7 @@ def test_load_lora_locally(self):
                 unet_lora_layers=lora_components["unet_lora_layers"],
                 text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                 text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
+                safe_serialization=False,
             )
             self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
             sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
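Note on the two new tests in examples/test_examples.py: they assert the end state of checkpoint rotation, i.e. with --checkpointing_steps=2, --checkpoints_total_limit=2 and --max_train_steps 7, checkpoints are written at steps 2, 4 and 6 and the oldest one is pruned. The snippet below is a hypothetical re-creation of that end state, not the training script's actual pruning code; prune_checkpoints is an illustrative helper name.

import os
import shutil
import tempfile

def prune_checkpoints(output_dir, total_limit):
    # Sort checkpoint-<step> directories by step number and drop all but the newest `total_limit`.
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints.sort(key=lambda name: int(name.split("-")[1]))
    for stale in checkpoints[:-total_limit]:
        shutil.rmtree(os.path.join(output_dir, stale))

with tempfile.TemporaryDirectory() as tmpdir:
    # Simulate the checkpoints produced at steps 2, 4 and 6.
    for step in (2, 4, 6):
        os.makedirs(os.path.join(tmpdir, f"checkpoint-{step}"))
    prune_checkpoints(tmpdir, total_limit=2)
    # Same assertion as the tests: checkpoint-2 is gone.
    assert {d for d in os.listdir(tmpdir) if "checkpoint" in d} == {"checkpoint-4", "checkpoint-6"}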
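Note on the load_model_hook change: LoraLoaderMixin.lora_state_dict returns one combined state dict covering the UNet and both text encoders, and the fix filters it by key prefix so each text encoder only receives its own entries. A minimal, self-contained sketch of that filtering follows; the keys and string values are illustrative placeholders, not real tensor names.

# Toy stand-ins for a combined LoRA state dict; values would normally be torch tensors.
lora_state_dict = {
    "unet.down_blocks.0.attentions.0.to_q.lora.up.weight": "unet-tensor",
    "text_encoder.text_model.encoder.layers.0.q_proj.lora.up.weight": "te1-tensor",
    "text_encoder_2.text_model.encoder.layers.0.q_proj.lora.up.weight": "te2-tensor",
}

# The trailing dot keeps the two filters disjoint: "text_encoder." is not a
# substring of any "text_encoder_2." key.
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}

assert list(text_encoder_state_dict.values()) == ["te1-tensor"]
assert list(text_encoder_2_state_dict.values()) == ["te2-tensor"]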
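Note on the pixel_values change in the training loop: pixel values are now only cast to the mixed-precision weight_dtype when a separate VAE is supplied via --pretrained_vae_model_name_or_path; otherwise they stay fp32 to match the VAE the script keeps in full precision (the stock SDXL VAE is known to be numerically unstable in fp16), and the latents are cast back to weight_dtype afterwards, per the hunk's trailing if args.pretrained_vae_model_name_or_path is None: branch. A small sketch under assumed values:

import torch

# Assumed stand-ins: in the script these come from the CLI args and the accelerator.
weight_dtype = torch.float16               # mixed-precision training dtype
pretrained_vae_model_name_or_path = None   # no fp16-safe VAE supplied

pixel_values = torch.randn(1, 3, 64, 64)   # dummy batch["pixel_values"], fp32

if pretrained_vae_model_name_or_path is not None:
    # A user-supplied (e.g. fp16-fixed) VAE can safely encode in half precision.
    pixel_values = pixel_values.to(dtype=weight_dtype)
# Otherwise pixel_values stay fp32 to match the fp32 VAE; the encoded latents
# are cast back to weight_dtype later in the script.

assert pixel_values.dtype == torch.float32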