[train_text_to_image] allow using non-ema weights for training #1834
Conversation
Looks great, thanks a lot!

However, I may not be fully understanding it yet. I know that `state_dict` and `load_state_dict` are used by `accelerate` during the checkpointing process, but I don't understand how `store` and `restore` are used. In addition, the line to resume from a checkpoint appears to have been removed; is resuming performed differently now?
```python
        temporarily stored. If `None`, the parameters with which this
        `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
```
Very minor question: why do we need the conversion to `list` here?
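One plausible reason (my assumption, not confirmed in the PR): `parameters` is usually the generator returned by `Module.parameters()`, which can only be consumed once, so it has to be materialized before it can be iterated more than once or stored.

```python
import torch
from torch import nn

model = nn.Linear(2, 2)

params = model.parameters()          # a generator, not a list
first_pass = [p.shape for p in params]
second_pass = [p.shape for p in params]
print(first_pass)                    # [torch.Size([2, 2]), torch.Size([2])]
print(second_pass)                   # [] -- the generator is already exhausted

params = list(model.parameters())    # materialize once...
assert [p.shape for p in params] == [p.shape for p in params]  # ...then reuse freely
```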
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
My bad, removed it by mistake.
```python
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True
```
This gives ~1.3x speed-up on A100.
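For readers landing here, a minimal sketch of how this kind of opt-in flag is typically wired up (the flag name mirrors the `--allow_tf32` argument added here; the rest is illustrative):

```python
import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--allow_tf32",
    action="store_true",
    help="Enable TF32 matmuls for faster full-precision training at slightly reduced precision.",
)
args = parser.parse_args()

if args.allow_tf32:
    # TF32 only takes effect on Ampere (and newer) GPUs; it is a no-op elsewhere.
    torch.backends.cuda.matmul.allow_tf32 = True
```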
```diff
@@ -541,7 +614,8 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
+    accelerator.register_for_checkpointing(lr_scheduler)
```
It's not required to register `lr_scheduler` here; it's automatically checkpointed by `accelerate`. We only need to register custom objects.
Interesting, thanks. This documentation led me to believe we needed to register it, but in those examples the learning rate scheduler is not being passed to `prepare`.
Hmm, should we maybe ask on the `accelerate` repo?
I've verified this: all standard objects that we pass to `prepare` (like `nn.Module`, `DataLoader`, `Optimizer`, `Scheduler`) are automatically checkpointed by `accelerate`. We only need to register custom objects or models that we don't pass to `prepare`.
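To make the distinction concrete, here is a hedged sketch (class, model, and path names are placeholders, not code from this PR): objects that go through `prepare` are saved by `save_state` automatically, while custom state is only saved if it exposes `state_dict`/`load_state_dict` and is registered explicitly.

```python
import torch
from torch import nn
from accelerate import Accelerator

accelerator = Accelerator()

model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Standard objects passed to prepare() are checkpointed automatically by accelerate.
model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)


class EmaState:
    """Placeholder custom object; it is only saved because it is registered below."""

    def __init__(self):
        self.step = 0

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state):
        self.step = state["step"]


ema_state = EmaState()
accelerator.register_for_checkpointing(ema_state)

accelerator.save_state("checkpoints/step_0")  # model/optimizer/scheduler + ema_state
accelerator.load_state("checkpoints/step_0")  # restores all of the above
```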
```python
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    )
```
Always padding to `max_length` now, to completely match the original implementation.
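A quick illustration of the difference (the tokenizer checkpoint is just an example):

```python
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

captions = ["a photo of a cat", "an astronaut riding a horse on mars"]

# Dynamic padding: sequences are only padded up to the longest caption in the batch.
dynamic = tokenizer(captions, padding=True, truncation=True, return_tensors="pt")

# Fixed padding: every batch is padded to tokenizer.model_max_length (77 for CLIP),
# matching what the original training setup does.
fixed = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)

print(dynamic["input_ids"].shape)  # torch.Size([2, <longest caption in the batch>])
print(fixed["input_ids"].shape)    # torch.Size([2, 77])
```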
Looks great!
Looks clean! I trust @pcuenca and @patil-suraj here :-)
This PR allows using `non-ema` weights for training and `ema` weights for EMA updates, to mimic the original training process. For now, the workflow is as follows:

- The non-ema weights live on a separate branch (`non-ema`), selected with the `--non_ema_revision` argument. If it's `None`, it will default to using ema weights for both training and EMA, as is the case now.
- If `--non_ema_revision` is specified, it will be used to load the `unet` for training, and the ema (`main`) weights will be used for EMA updates.

This approach of using branches is not the best solution, but will be used until we have the `variations` feature in `diffusers`.

This PR also:

- ensures the `unet` is checkpointed.
- adds `--allow_tf32` to enable TF32 on Ampere GPUs (A100) for faster full-precision training. Gives about ~1.33x speed-up.

Example command:
Fixes #1153
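As a recap of the workflow above, a hedged sketch of the two loads involved (the model ID and variable names are examples, not the PR's exact code):

```python
from diffusers import UNet2DConditionModel

pretrained_model_name_or_path = "CompVis/stable-diffusion-v1-4"  # example model repo
non_ema_revision = "non-ema"  # would come from --non_ema_revision; None keeps today's behaviour

# Trainable unet: non-EMA weights if a revision is given, otherwise the default (EMA) weights.
unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=non_ema_revision
)

# The EMA copy is always initialized from the ema weights on the main revision.
ema_unet = UNet2DConditionModel.from_pretrained(
    pretrained_model_name_or_path, subfolder="unet", revision=None
)
```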