
[train_text_to_image] allow using non-ema weights for training #1834

Merged: 17 commits, Dec 30, 2022

Conversation

@patil-suraj (Contributor) commented Dec 26, 2022

This PR allows using non-EMA weights for training and EMA weights for EMA updates, to mimic the original training process. For now, the workflow is as follows:

  • Each pre-trained SD checkpoint will have a branch called non-ema.
  • The script allows specifying this branch with the --non_ema_revision argument. If it's None, it defaults to using the EMA weights for both training and EMA updates, as is the case now.
  • If --non_ema_revision is specified, it is used to load the unet for training, and the EMA (main branch) weights are used for the EMA updates, as in the sketch below.

This approach of using branches is not the best solution, but it will be used until we have the variations feature in diffusers.
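For illustration, here is a minimal sketch of what loading the two revisions could look like, using the standard diffusers from_pretrained API (the exact wiring in the training script may differ):

from diffusers import UNet2DConditionModel

model_name = "runwayml/stable-diffusion-v1-5"

# Non-EMA weights from the `non-ema` branch become the trainable unet.
unet = UNet2DConditionModel.from_pretrained(
    model_name, subfolder="unet", revision="non-ema"
)

# EMA weights from the main branch are only used to initialize the EMA copy.
ema_unet = UNet2DConditionModel.from_pretrained(
    model_name, subfolder="unet", revision="main"
)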

This PR also

  • allows checkpointing of the EMA model. Currently only the trained unet is checkpointed.
  • adds an --allow_tf32 argument to enable TF32 on Ampere GPUs (e.g. A100) for faster full-precision training. This gives a ~1.33x speed-up.

Example command:

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export NON_EMA_REVISION="non-ema"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
export WANDB_PROJECT="stable-diffusion-pokemon"

accelerate launch --multi_gpu --gpu_ids="0,1" --mixed_precision="no" \
   ../diffusers/examples/text_to_image/train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --non_ema_revision=$NON_EMA_REVISION --use_ema \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=16 --gradient_checkpointing \
  --max_train_steps=5000 --checkpointing_steps=1000 \
  --learning_rate=3e-05 \
  --lr_scheduler="constant" --lr_warmup_steps=0

Fixes #1153

@HuggingFaceDocBuilderDev commented Dec 26, 2022

The documentation is not available anymore as the PR was closed or merged.

@pcuenca (Member) left a comment

Looks great, thanks a lot!

However, I may not be fully understanding it yet. I know that state_dict and load_state_dict are used by accelerate during the checkpointing process, but I don't understand how store and restore are used. In addition, the line to resume from a checkpoint appears to have been removed; is resuming performed differently now?

temporarily stored. If `None`, the parameters with which this
`ExponentialMovingAverage` was initialized will be used.
"""
parameters = list(parameters)
A Member commented on this excerpt:

Very minor question, why do we need the conversion to list here?

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
@patil-suraj (Contributor, Author) commented Dec 27, 2022

but I don't understand how store and restore are used.

The store and restore methods can be used for evaluation during training. For evaluation we want the EMA params, so we temporarily save the training params in EMAModel using store, copy the EMA params to the model, run the evaluation, and then use restore to put the training parameters back. But this is not currently used in the script, so I will remove these methods.
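For context, a minimal sketch of that evaluation pattern, assuming the EMAModel interface in diffusers.training_utils (store, copy_to, restore); the EMAModel defined in the example script at the time may differ slightly:

import torch
from diffusers.training_utils import EMAModel

# Toy model standing in for the unet being trained.
model = torch.nn.Linear(4, 4)
ema_model = EMAModel(model.parameters())

# ... during training, ema_model.step(model.parameters()) updates the EMA copy ...

# Evaluation with EMA weights:
ema_model.store(model.parameters())    # stash the current training weights
ema_model.copy_to(model.parameters())  # load the EMA weights into the model
# ... run validation with `model` here ...
ema_model.restore(model.parameters())  # put the training weights back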

In addition, the line to resume from a checkpoint appears to have been removed, is resuming performed differently now?

My bad, removed it by mistake.

@patil-suraj patil-suraj changed the title [wip][examples/train_text_to_image] allow using non-ema weights for training [train_text_to_image] allow using non-ema weights for training Dec 30, 2022
Comment on lines +473 to +474
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
The PR author (patil-suraj) commented on this excerpt:

This gives a ~1.3x speed-up on A100.
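For reference, a minimal sketch of how such a flag can be wired up (the flag name matches the PR; the parser setup here is illustrative):

import argparse

import torch

parser = argparse.ArgumentParser()
parser.add_argument(
    "--allow_tf32",
    action="store_true",
    help="Allow TF32 on Ampere GPUs to speed up full-precision training.",
)
args = parser.parse_args()

if args.allow_tf32:
    # TF32 trades a few mantissa bits of float32 matmuls for a large speed-up on Ampere.
    torch.backends.cuda.matmul.allow_tf32 = True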

@@ -541,7 +614,8 @@ def collate_fn(examples):
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
accelerator.register_for_checkpointing(lr_scheduler)
The PR author (patil-suraj) commented:

It's not required to register lr_scheduler here; it's automatically checkpointed by accelerate. We only need to register custom objects.

A Member commented:

Interesting, thanks. This documentation led me to believe we needed to register it, but in those examples the learning rate scheduler is not being passed to prepare.

A Contributor commented:

Hmm should we maybe ask on the accelerate repo?

The PR author (patil-suraj) commented:

I've verified this: all standard objects that we pass to prepare (like nn.Module, DataLoader, Optimizer, Scheduler) are automatically checkpointed by accelerate. We only need to register custom objects or models that we don't pass to prepare.
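A minimal sketch of the distinction, using toy stand-in objects (a registered object only needs to implement state_dict/load_state_dict):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator()

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)

# Objects passed to prepare() are tracked and saved by save_state() automatically.
model, optimizer, dataloader, scheduler = accelerator.prepare(
    model, optimizer, dataloader, scheduler
)

# A custom object (e.g. an EMA copy) is not passed to prepare(), so it must be
# registered explicitly before it is included in checkpoints.
ema_model = torch.nn.Linear(4, 4)  # stand-in for the EMA model
accelerator.register_for_checkpointing(ema_model)

# Both the prepared objects and the registered object end up in the checkpoint.
accelerator.save_state("checkpoints/step-1000")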

Comment on lines +575 to +577
inputs = tokenizer(
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
The PR author (patil-suraj) commented on this excerpt:

Always padding to max_length now, to completely match the original implementation.

@patil-suraj patil-suraj requested a review from pcuenca December 30, 2022 14:12
@pcuenca (Member) left a comment

Looks great!

@patrickvonplaten (Contributor) left a comment

Looks clean! I trust @pcuenca and @patil-suraj here :-)

Successfully merging this pull request may close these issues.

Why does train_text_to_image.py perform so differently from the CompVis script?
4 participants