
Train Unconditional - LR Scheduler init fix  #1957

Closed
@mathieujouffroy

Description


Learning Rate Scheduler with gradient accumulation

Hi,

There seems to be a small error in the train_unconditional.py script regarding the initialization of the learning rate scheduler (lr_scheduler, line 338).

Since the script uses gradient accumulation with Accelerate and logs progress with global_step, shouldn't we multiply num_warmup_steps by gradient_accumulation_steps and set the number of training steps to len(train_dataloader) * args.num_epochs (instead of dividing it by gradient_accumulation_steps)?
With the current settings, the learning rate updates don't look correct when gradient_accumulation_steps is greater than 1.

Therefore, instead of having:

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)

Shouldn't we set it to:

lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=len(train_dataloader) * args.num_epochs,
)

Thanks 🤗

Reproduction

accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=8 \
  --num_epochs=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=5e-5 \
  --lr_warmup_steps=80 \
  --mixed_precision=no 

With these settings, inspecting the progress bar logs shows that the learning rate (following the default cosine scheduler) increases to the initial learning_rate (5e-5) 4 times faster than expected (a factor equal to gradient_accumulation_steps) before decreasing towards 0: it reaches 5e-5 at step 20 instead of step 80.
[Screenshot: progress bar logs showing the learning rate reaching 5e-5 at step 20]
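
The mismatch can also be reproduced without the full training script. The standalone sketch below (dummy optimizer; steps_per_epoch is an arbitrary stand-in for len(train_dataloader)) steps the scheduler once per micro-batch and increments global_step once every gradient_accumulation_steps micro-batches, mirroring the observed behaviour: with the current initialization the peak learning rate is reached at global_step 20, with the proposed one at global_step 80.

import torch
from diffusers.optimization import get_scheduler

grad_accum, warmup, num_epochs, steps_per_epoch, lr = 4, 80, 1, 1000, 5e-5

def peak_global_step(num_warmup_steps, num_training_steps):
    # Dummy optimizer; we only care about the lr values the scheduler produces.
    optimizer = torch.optim.AdamW(torch.nn.Linear(2, 2).parameters(), lr=lr)
    scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    global_step = 0
    for micro_step in range(1, steps_per_epoch * num_epochs + 1):
        scheduler.step()                      # stepped once per micro-batch
        if micro_step % grad_accum == 0:
            global_step += 1                  # what the progress bar counts
            if scheduler.get_last_lr()[0] >= lr:
                return global_step            # first global step at the peak lr
    return None

# Current initialization: warmup ends at global_step 20 (matches the logs above)
print(peak_global_step(warmup, (steps_per_epoch * num_epochs) // grad_accum))
# Proposed initialization: warmup ends at global_step 80
print(peak_global_step(warmup * grad_accum, steps_per_epoch * num_epochs))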

System Info

  • diffusers version: 0.11.1
  • Platform: Linux-5.4.0-109-generic-x86_64-with-glibc2.29
  • Python version: 3.8.12
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.25.1
