
Train Unconditional - LR Scheduler init fix #1957

Closed
mathieujouffroy opened this issue Jan 9, 2023 · 5 comments · Fixed by #2010
Labels: bug (Something isn't working)

Comments


mathieujouffroy commented Jan 9, 2023

Learning Rate Scheduler with gradient accumulation

Hi,

There seems to be a small error in the train_unconditional.py script regarding the initialization of the learning rate scheduler (lr_scheduler line 338).

As we are using gradient accumulation with Accelerate and global_step for logging, shouldn't we multiply num_warmup_steps by the gradient accumulation steps, and set the number of training steps to len(train_dataloader) * args.num_epochs (instead of dividing it by the gradient accumulation steps)?
With the current settings, the learning rate updates don't look correct when gradient_accumulation_steps is higher than 1.

Therefore, instead of having:

```python
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
```

Shouldn't we set it to:

```python
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=len(train_dataloader) * args.num_epochs,
)
```
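To make the unit mismatch explicit, here is a rough counting sketch (purely illustrative: B, E and G below are placeholder values, and it assumes lr_scheduler.step() is called once per dataloader batch, as the training loop does):

```python
# Illustrative counting only -- B, E, G are placeholders, not the real dataset size.
B, E, G = 1000, 1, 4              # len(train_dataloader), num_epochs, gradient_accumulation_steps

scheduler_calls = B * E           # lr_scheduler.step() fires once per batch
optimizer_steps = (B * E) // G    # parameters are only updated every G batches

# Current script: num_warmup_steps is counted in per-batch scheduler calls,
# while num_training_steps is counted in optimizer steps, so the two arguments
# use different units.
current = dict(num_warmup_steps=80, num_training_steps=(B * E) // G)

# Proposal: count both arguments in per-batch scheduler calls.
proposed = dict(num_warmup_steps=80 * G, num_training_steps=B * E)
```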

Thanks 🤗

Reproduction

```bash
accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=8 \
  --num_epochs=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=5e-5 \
  --lr_warmup_steps=80 \
  --mixed_precision=no
```

With these settings, inspecting the progress bar logs shows that the learning rate (following the default cosine scheduler) warms up to the initial learning_rate (5e-5) 4 times faster than expected (a factor equal to gradient_accumulation_steps) before decaying towards 0: it reaches 5e-5 at step 20 instead of step 80.
[Screenshot (2023-01-09): progress bar logs showing the learning rate reaching 5e-5 at step 20]
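(Illustrative arithmetic, assuming lr_scheduler.step() runs once per dataloader batch while global_step only advances on optimizer steps: the 80 warmup calls are used up after 80 batches, i.e. 80 / 4 = 20 logged steps, which is what the screenshot shows.)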

System Info

  • diffusers version: 0.11.1
  • Platform: Linux-5.4.0-109-generic-x86_64-with-glibc2.29
  • Python version: 3.8.12
  • PyTorch version (GPU?): 1.12.1+cu113 (True)
  • Huggingface_hub version: 0.10.1
  • Transformers version: 4.25.1
mathieujouffroy added the bug label on Jan 9, 2023
@johnowhitaker (Contributor)

@muellerzr do you happen to know how accelerate handles gradient accumulation with lr schedulers? Perhaps we don't need any of this manual adjustment at all (in which case some other scripts likely also need tweaking).
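For reference, Accelerate's basic gradient-accumulation pattern looks roughly like this (a minimal toy sketch, not the diffusers script itself; whether the prepared scheduler skips its internal step during accumulation may depend on the accelerate version):

```python
# Minimal toy sketch of accelerate's gradient-accumulation pattern.
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
train_dataloader = torch.utils.data.DataLoader(torch.randn(32, 4), batch_size=8)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: 1.0)

model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

for batch in train_dataloader:
    with accelerator.accumulate(model):
        loss = model(batch).pow(2).mean()
        accelerator.backward(loss)
        optimizer.step()      # a real parameter update only happens on sync steps
        lr_scheduler.step()   # called on every batch in this pattern
        optimizer.zero_grad()
```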


mathieujouffroy commented Jan 11, 2023

Hello, I may be missing something with Accelerate, but from what I've seen in a run of 1 epoch (256 steps) with gradient_accumulation_steps=4 and warmup_steps=80, following the cosine scheduler, the lr rises to the initial lr and decays back to 0 gradient_accumulation_steps times faster than configured, and then keeps rising and falling again until the end of the for step, batch in enumerate(train_dataloader) loop.

Reproduction:

```bash
accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --train_batch_size=8 \
  --num_epochs=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=5e-5 \
  --lr_warmup_steps=80 \
  --mixed_precision=no
```

From my understanding, this is due to setting num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps in the learning rate scheduler. Therefore, in my case, the learning rate scheduler ends at step 64 before restarting.

Here's the plot of the learning rate for that run:
[Screenshot (2023-01-11): learning rate curve for the run, cycling up and down as described above]
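For what it's worth, the shape of that plot can be reproduced offline with a dummy optimizer (a sketch only: the 1024 batches below are an assumption, i.e. 256 optimizer steps × gradient_accumulation_steps=4, and it assumes lr_scheduler.step() runs once per batch):

```python
# Offline reproduction of the schedule shape (dummy optimizer, no training).
import torch
from diffusers.optimization import get_scheduler

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-5)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=80,
    num_training_steps=1024 // 4,   # what the current script computes
)

lrs = []
for _ in range(1024):               # one scheduler step per batch
    optimizer.step()
    lr_scheduler.step()
    lrs.append(lr_scheduler.get_last_lr()[0])

# lrs peaks at 5e-5 around batch 80 (global step 20), decays to ~0 around
# batch 256 (global step 64), and then cycles back up, matching the plot.
```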

Am I missing something?

@patrickvonplaten (Contributor)

cc @anton-l, @patil-suraj

@patil-suraj (Contributor)

@mathieujouffroy Thanks a lot for the detailed issue, and you are right. The num_training_steps should be set to len(train_dataloader) * args.num_epochs and accelerate will take care of calling the actual step at the right time with grad accumulation.

@mathieujouffroy (Author)

Hi @patil-suraj, ok cool 😊 thanks for your response!
I'm happy to help!
