Train Unconditional - LR Scheduler init fix #1957
Comments
@muellerzr do you happen to know how accelerate handles gradient accumulation with lr schedulers? Perhaps we don't need any of this manual adjustment at all (in which case some other scripts likely also need tweaking)

cc @anton-l, @patil-suraj

@mathieujouffroy Thanks a lot for the detailed issue, and you are right.

Hi @patil-suraj, ok cool 😊 thanks for your response!
Learning Rate Scheduler with gradient accumulation
Hi,
There seems to be a small error in the `train_unconditional.py` script regarding the initialization of the learning rate scheduler (`lr_scheduler`, line 338). As we are using gradient accumulation with Accelerate and `global_step` for logging, shouldn't we multiply `num_warmup_steps` by the gradient accumulation steps and set the number of training steps to `len(train_dataloader) * args.num_epochs` (instead of dividing by the gradient accumulation steps)? With the current settings, the updates to the learning rate don't seem correct when using a `gradient_accumulation_steps` higher than 1.

Therefore, instead of having:
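For reference, this is roughly the current call (paraphrasing line 338 of the 0.11.1 script; the exact formatting there may differ slightly):

```python
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps,
    num_training_steps=(len(train_dataloader) * args.num_epochs) // args.gradient_accumulation_steps,
)
```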
Shouldn't we set it to:
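Something along these lines (a sketch of the proposed change, not a tested patch):

```python
lr_scheduler = get_scheduler(
    args.lr_scheduler,
    optimizer=optimizer,
    num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
    num_training_steps=len(train_dataloader) * args.num_epochs,
)
```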
Thanks 🤗
Reproduction
With these settings, when inspecting the progress bar logs, we can see that the learning rate (following the default cosine scheduler) increases to the initial `learning_rate` (`5e-5`) 4 times faster than expected (a factor equal to `gradient_accumulation_steps`) before decreasing towards 0. It reaches `5e-5` at step 20 instead of step 80.
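Since the original training command isn't captured above, here is a self-contained sketch of the mechanism with assumed values (cosine schedule, `lr_warmup_steps=80`, `gradient_accumulation_steps=4`, peak LR `5e-5`): the scheduler is stepped once per batch while `global_step` only advances once per accumulation cycle, so warmup finishes at `global_step` 20 instead of 80:

```python
import torch
from diffusers.optimization import get_scheduler

gradient_accumulation_steps = 4  # assumed value
num_batches = 1000               # stand-in for len(train_dataloader) * args.num_epochs
peak_lr = 5e-5

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr)
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=80,  # unscaled warmup steps, as in the current script
    num_training_steps=num_batches // gradient_accumulation_steps,
)

global_step = 0
for batch_idx in range(num_batches):
    optimizer.step()     # the real script skips the update on non-sync steps,
    lr_scheduler.step()  # but the scheduler is stepped on every batch
    if (batch_idx + 1) % gradient_accumulation_steps == 0:
        global_step += 1
        if lr_scheduler.get_last_lr()[0] >= peak_lr:
            print(f"peak LR reached at global_step {global_step}")  # prints 20, not 80
            break
```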
System Info

- `diffusers` version: 0.11.1