[train_text_to_image] allow using non-ema weights for training #1834
Changes from 16 commits
@@ -1,4 +1,5 @@
 import argparse
+import copy
 import logging
 import math
 import os
@@ -11,6 +12,9 @@
 import torch.nn.functional as F
 import torch.utils.checkpoint

+import datasets
+import diffusers
+import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import set_seed
@@ -28,7 +32,7 @@
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.10.0.dev0")

-logger = get_logger(__name__)
+logger = get_logger(__name__, log_level="INFO")


 def parse_args():
@@ -171,7 +175,25 @@ def parse_args():
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
+    parser.add_argument(
+        "--allow_tf32",
+        action="store_true",
+        help=(
+            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+        ),
+    )
+    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+    parser.add_argument(
+        "--non_ema_revision",
+        type=str,
+        default=None,
+        required=False,
+        help=(
+            "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
+            " remote repository specified with --pretrained_model_name_or_path."
+        ),
+    )
     parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
     parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
@@ -247,6 +269,10 @@ def parse_args():
     if args.dataset_name is None and args.train_data_dir is None:
         raise ValueError("Need either a dataset name or a training folder.")

+    # default to using the same revision for the non-ema model if not specified
+    if args.non_ema_revision is None:
+        args.non_ema_revision = args.revision
+
     return args
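For illustration (not part of the diff), a minimal sketch of the fallback above, assuming `parse_args` from this script; the model and dataset names are only placeholders:

```python
# Hypothetical invocation: --non_ema_revision is omitted, so it inherits --revision.
import sys

sys.argv = [
    "train_text_to_image.py",
    "--pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4",
    "--dataset_name", "lambdalabs/pokemon-blip-captions",
    "--revision", "fp16",
]
args = parse_args()
assert args.non_ema_revision == "fp16"
```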
@@ -275,6 +301,8 @@ def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
         parameters = list(parameters)
         self.shadow_params = [p.clone().detach() for p in parameters]

+        self.collected_params = None
+
         self.decay = decay
         self.optimization_step = 0
@@ -322,6 +350,55 @@ def to(self, device=None, dtype=None) -> None:
             for p in self.shadow_params
         ]

+    def state_dict(self) -> dict:
+        r"""
+        Returns the state of the ExponentialMovingAverage as a dict.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        """
+        # Following PyTorch conventions, references to tensors are returned:
+        # "returns a reference to the state and not its copy!" -
+        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+        return {
+            "decay": self.decay,
+            "optimization_step": self.optimization_step,
+            "shadow_params": self.shadow_params,
+            "collected_params": self.collected_params,
+        }
+
+    def load_state_dict(self, state_dict: dict) -> None:
+        r"""
+        Loads the ExponentialMovingAverage state.
+        This method is used by accelerate during checkpointing to save the ema state dict.
+        Args:
+            state_dict (dict): EMA state. Should be an object returned
+                from a call to :meth:`state_dict`.
+        """
+        # deepcopy, to be consistent with module API
+        state_dict = copy.deepcopy(state_dict)
+
+        self.decay = state_dict["decay"]
+        if self.decay < 0.0 or self.decay > 1.0:
+            raise ValueError("Decay must be between 0 and 1")
+
+        self.optimization_step = state_dict["optimization_step"]
+        if not isinstance(self.optimization_step, int):
+            raise ValueError("Invalid optimization_step")
+
+        self.shadow_params = state_dict["shadow_params"]
+        if not isinstance(self.shadow_params, list):
+            raise ValueError("shadow_params must be a list")
+        if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
+            raise ValueError("shadow_params must all be Tensors")
+
+        self.collected_params = state_dict["collected_params"]
+        if self.collected_params is not None:
+            if not isinstance(self.collected_params, list):
+                raise ValueError("collected_params must be a list")
+            if not all(isinstance(p, torch.Tensor) for p in self.collected_params):
+                raise ValueError("collected_params must all be Tensors")
+            if len(self.collected_params) != len(self.shadow_params):
+                raise ValueError("collected_params and shadow_params must have the same length")
+
+
 def main():
     args = parse_args()
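As a quick illustration (not part of the diff), a minimal sketch of how the new `state_dict`/`load_state_dict` methods could be exercised, assuming the `EMAModel` class defined in this file; the toy `torch.nn.Linear` stands in for the UNet:

```python
# Minimal sketch: round-trip the EMA state the way a checkpoint save/load would.
import torch

net = torch.nn.Linear(4, 4)                 # stand-in for the UNet
ema = EMAModel(net.parameters())            # EMAModel as defined above in this file

# ... ema.step(net.parameters()) would be called during training ...

state = ema.state_dict()                    # returns references to the shadow tensors
torch.save(state, "ema_state.pt")

restored = EMAModel(net.parameters())
restored.load_state_dict(torch.load("ema_state.pt"))
assert restored.optimization_step == ema.optimization_step
assert len(restored.shadow_params) == len(ema.shadow_params)
```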
@@ -339,6 +416,15 @@ def main():
         datefmt="%m/%d/%Y %H:%M:%S",
         level=logging.INFO,
     )
+    logger.info(accelerator.state, main_process_only=False)
+    if accelerator.is_local_main_process:
+        datasets.utils.logging.set_verbosity_warning()
+        transformers.utils.logging.set_verbosity_info()
+        diffusers.utils.logging.set_verbosity_info()
+    else:
+        datasets.utils.logging.set_verbosity_error()
+        transformers.utils.logging.set_verbosity_error()
+        diffusers.utils.logging.set_verbosity_error()

     # If passed along, set the training seed now.
     if args.seed is not None:
@@ -361,39 +447,44 @@ def main():
     elif args.output_dir is not None:
         os.makedirs(args.output_dir, exist_ok=True)

-    # Load models and create wrapper for stable diffusion
+    # Load scheduler, tokenizer and models.
+    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="text_encoder",
-        revision=args.revision,
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="vae",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path,
-        subfolder="unet",
-        revision=args.revision,
+        args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )

+    # Freeze vae and text_encoder
+    vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
+    # Create EMA for the unet.
+    if args.use_ema:
+        ema_unet = UNet2DConditionModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
+        )
+        ema_unet = EMAModel(ema_unet.parameters())
+
     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
             unet.enable_xformers_memory_efficient_attention()
         else:
             raise ValueError("xformers is not available. Make sure it is installed correctly")

-    # Freeze vae and text_encoder
-    vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()

+    # Enable TF32 for faster training on Ampere GPUs,
+    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
+    if args.allow_tf32:
+        torch.backends.cuda.matmul.allow_tf32 = True
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
@@ -419,7 +510,6 @@ def main():
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
     )
-    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -482,9 +572,10 @@ def tokenize_captions(examples, is_train=True):
                raise ValueError(
                    f"Caption column `{caption_column}` should contain either strings or lists of strings."
                )
-        inputs = tokenizer(captions, max_length=tokenizer.model_max_length, padding="do_not_pad", truncation=True)
-        input_ids = inputs.input_ids
-        return input_ids
+        inputs = tokenizer(
+            captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
+        )
+        return inputs.input_ids

    train_transforms = transforms.Compose(
        [

Comment on lines +575 to +577: always padding to …
@@ -500,7 +591,6 @@ def preprocess_train(examples):
         images = [image.convert("RGB") for image in examples[image_column]]
         examples["pixel_values"] = [train_transforms(image) for image in images]
         examples["input_ids"] = tokenize_captions(examples)
-
         return examples

     with accelerator.main_process_first():
@@ -512,13 +602,8 @@ def preprocess_train(examples):
     def collate_fn(examples):
         pixel_values = torch.stack([example["pixel_values"] for example in examples])
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
-        input_ids = [example["input_ids"] for example in examples]
-        padded_tokens = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt")
-        return {
-            "pixel_values": pixel_values,
-            "input_ids": padded_tokens.input_ids,
-            "attention_mask": padded_tokens.attention_mask,
-        }
+        input_ids = torch.stack([example["input_ids"] for example in examples])
+        return {"pixel_values": pixel_values, "input_ids": input_ids}

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size
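To make the tokenization and collation changes above concrete, a hedged sketch (the CLIP checkpoint name is only an example): captions are now padded to `tokenizer.model_max_length` with `return_tensors="pt"` during preprocessing, so `collate_fn` can simply stack the per-example `input_ids` and no longer needs `tokenizer.pad` or an `attention_mask`:

```python
# Sketch of the new tokenization + collation path (example checkpoint, toy captions).
import torch
from transformers import CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
captions = ["a photo of a cat", "an astronaut riding a horse"]

inputs = tokenizer(
    captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
)
examples = [{"input_ids": ids} for ids in inputs.input_ids]  # what the dataset now stores per example

# collate_fn after this diff: every example already has the same length, so stacking is enough
input_ids = torch.stack([example["input_ids"] for example in examples])
print(input_ids.shape)  # e.g. torch.Size([2, 77]) for CLIP's 77-token context
```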
@@ -541,7 +626,8 @@ def collate_fn(examples):
     unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
         unet, optimizer, train_dataloader, lr_scheduler
     )
-    accelerator.register_for_checkpointing(lr_scheduler)
+    if args.use_ema:
+        accelerator.register_for_checkpointing(ema_unet)

     weight_dtype = torch.float32
     if accelerator.mixed_precision == "fp16":

Review thread on `accelerator.register_for_checkpointing(lr_scheduler)`:
- It's not required to register …
- Interesting, thanks. This documentation led me to believe we needed to register it, but in those examples the learning rate scheduler is not being passed to …
- Hmm should we maybe ask on the …
- I've verified this, all standard objects that we pass to …
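Regarding the review thread above, a small self-contained sketch (toy model and optimizer, `EMAModel` assumed from this file) of the behaviour being discussed: objects passed to `accelerator.prepare` are saved and restored by `save_state`/`load_state` on their own, while a custom object like `EMAModel` has to be registered explicitly, which is what the new `state_dict`/`load_state_dict` methods enable:

```python
# Sketch only: contrasts prepare()-tracked objects with an explicitly registered one.
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
# no accelerator.register_for_checkpointing(lr_scheduler) needed: prepare() already tracks it

ema = EMAModel(model.parameters())           # custom object, not something prepare() handles
accelerator.register_for_checkpointing(ema)  # saved/loaded via its state_dict()/load_state_dict()

accelerator.save_state("checkpoint-0")       # writes model, optimizer, scheduler and the EMA state
accelerator.load_state("checkpoint-0")       # restores all of them
```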
@@ -554,10 +640,8 @@ def collate_fn(examples):
     # as these models are only used for inference, keeping weights in full precision is not required.
     text_encoder.to(accelerator.device, dtype=weight_dtype)
     vae.to(accelerator.device, dtype=weight_dtype)

-    # Create EMA for the unet.
     if args.use_ema:
-        ema_unet = EMAModel(unet.parameters())
         ema_unet.to(accelerator.device)

     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
Review comment: This gives ~1.3x speed-up on A100.
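A rough micro-benchmark sketch of the TF32 toggle added in this PR (assumed to be what the speed-up comment refers to; matrix sizes and iteration count are arbitrary, and the effect only appears on Ampere or newer GPUs):

```python
# Compare matmul throughput with TF32 disabled and enabled; numbers will vary by GPU.
import time
import torch

a = torch.randn(4096, 4096, device="cuda")
b = torch.randn(4096, 4096, device="cuda")

for allow_tf32 in (False, True):
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(50):
        a @ b
    torch.cuda.synchronize()
    print(f"allow_tf32={allow_tf32}: {time.time() - start:.3f}s")
```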