Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/dreambooth/train_dreambooth_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,11 +843,15 @@ def load_model_hook(models, input_dir):

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)

accelerator.register_save_state_pre_hook(save_model_hook)
Expand Down
71 changes: 71 additions & 0 deletions examples/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,77 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self):
)
self.assertTrue(starts_with_unet)

def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--checkpointing_steps=2
--checkpoints_total_limit=2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)

def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--checkpointing_steps=2
--checkpoints_total_limit=2
--train_text_encoder
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
""".split()

run_command(self._launch_args + test_args)

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe("a prompt", num_inference_steps=2)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)

def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
Expand Down