
Commit 60f9f39

geniuspatrick authored and DN6 committed
fix(training): lr scheduler doesn't work properly in distributed scenarios
1 parent a2ecce2 commit 60f9f39

File tree

36 files changed: 193 additions & 109 deletions

examples/text_to_image/train_text_to_image_lora.py

Lines changed: 18 additions & 7 deletions
@@ -697,17 +697,22 @@ def collate_fn(examples):
     )
 
     # Scheduler and math around the number of training steps.
-    overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    # Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
+    num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
     if args.max_train_steps is None:
-        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-        overrode_max_train_steps = True
+        len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
+        num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
+        num_training_steps_for_scheduler = (
+            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+        )
+    else:
+        num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
-        num_training_steps=args.max_train_steps * accelerator.num_processes,
+        num_warmup_steps=num_warmup_steps_for_scheduler,
+        num_training_steps=num_training_steps_for_scheduler,
     )
 
     # Prepare everything with our `accelerator`.
@@ -717,8 +722,14 @@ def collate_fn(examples):
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
-    if overrode_max_train_steps:
+    if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+        if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
+            logger.warning(
+                f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
+                f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
+                f"This inconsistency may result in the learning rate scheduler not functioning properly."
+            )
     # Afterwards we recalculate our number of training epochs
     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
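Why the scaling by accelerator.num_processes: the scheduler is created before accelerator.prepare, when the dataloader is not yet sharded, and a prepared scheduler effectively advances once per process for every optimizer step. The fix therefore estimates the per-process dataloader length up front and expresses the warmup and total steps in "process-steps". Below is a minimal sketch of that arithmetic only; the helper name and example numbers are illustrative, not part of the commit.

import math

def scheduler_step_budget(
    dataloader_len,               # len(train_dataloader) BEFORE accelerator.prepare
    num_processes,                # accelerator.num_processes
    gradient_accumulation_steps,
    num_train_epochs,
    lr_warmup_steps,
    max_train_steps=None,
):
    # Warmup and total steps are scaled because the prepared scheduler is
    # stepped once per process for each optimizer step.
    num_warmup_steps_for_scheduler = lr_warmup_steps * num_processes

    if max_train_steps is None:
        # Estimate the per-process dataloader length after sharding.
        len_after_sharding = math.ceil(dataloader_len / num_processes)
        num_update_steps_per_epoch = math.ceil(len_after_sharding / gradient_accumulation_steps)
        num_training_steps_for_scheduler = (
            num_train_epochs * num_update_steps_per_epoch * num_processes
        )
    else:
        num_training_steps_for_scheduler = max_train_steps * num_processes

    return num_warmup_steps_for_scheduler, num_training_steps_for_scheduler

# Example: 1000 batches, 4 processes, grad accumulation 2, 10 epochs, 500 warmup steps
print(scheduler_step_budget(1000, 4, 2, 10, 500))  # -> (2000, 5000)

The second hunk then compares this pre-prepare estimate against the actual post-prepare dataloader length and only logs a warning if they disagree, since by that point the scheduler has already been created.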

src/diffusers/pipelines/animatediff/pipeline_animatediff.py

Lines changed: 4 additions & 3 deletions
@@ -316,9 +316,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 
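The same guard is repeated in each of the pipeline encode_prompt changes below: the LoRA unscaling is now skipped when the pipeline has no text encoder (for example, text_encoder=None with prompt embeddings supplied directly), where the previous code could fail. A minimal standalone sketch of the pattern; the class name and the stubbed helpers are hypothetical, not the diffusers implementation.

USE_PEFT_BACKEND = True  # stand-in for diffusers.utils.USE_PEFT_BACKEND

def unscale_lora_layers(text_encoder, lora_scale):
    # Stub for diffusers.utils.unscale_lora_layers: restores the original LoRA
    # scale that was applied while encoding the prompt.
    print(f"unscaling LoRA layers (scale={lora_scale})")

class PipelineSketch:
    def __init__(self, text_encoder=None):
        self.text_encoder = text_encoder

    def encode_prompt(self, lora_scale=1.0):
        # ... prompt encoding elided ...
        # Guard added by this commit: only unscale when a text encoder exists.
        if self.text_encoder is not None:
            if USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(self.text_encoder, lora_scale)

PipelineSketch(text_encoder=None).encode_prompt()             # no-op, no error
PipelineSketch(text_encoder="text_encoder").encode_prompt()   # unscales as before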

src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -420,9 +420,10 @@ def encode_prompt(
420420
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
421421
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
422422

423-
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
424-
# Retrieve the original scale by scaling back the LoRA layers
425-
unscale_lora_layers(self.text_encoder, lora_scale)
423+
if self.text_encoder is not None:
424+
if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
425+
# Retrieve the original scale by scaling back the LoRA layers
426+
unscale_lora_layers(self.text_encoder, lora_scale)
426427

427428
return prompt_embeds, negative_prompt_embeds
428429

src/diffusers/pipelines/controlnet/pipeline_controlnet.py

Lines changed: 4 additions & 3 deletions
@@ -463,9 +463,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

Lines changed: 4 additions & 3 deletions
@@ -441,9 +441,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

Lines changed: 4 additions & 3 deletions
@@ -566,9 +566,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py

Lines changed: 4 additions & 3 deletions
@@ -390,9 +390,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py

Lines changed: 4 additions & 3 deletions
@@ -456,9 +456,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py

Lines changed: 4 additions & 3 deletions
@@ -426,9 +426,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py

Lines changed: 4 additions & 3 deletions
@@ -364,9 +364,10 @@ def encode_prompt(
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
+        if self.text_encoder is not None:
+            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
 
         return prompt_embeds, negative_prompt_embeds
 
