From 9cfa68c92fc45fe6e78745d1ec5544b3d7b5cf5f Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 20 Jan 2024 08:46:53 +0800
Subject: [PATCH] [Experimental Feature] FP8 weight dtype for base model when
 running train_network (or sdxl_train_network) (#1057)

* Add fp8 support
* remove some debug prints
* Better implementation for te
* Fix some misunderstanding
* as same as unet, add explicit convert
* better impl for convert TE to fp8
* fp8 for not only unet
* Better cache TE and TE lr
* match arg name
* Fix with list
* Add timeout settings
* Fix arg style
* Add custom seperator
* Fix typo
* Fix typo again
* Fix dtype error
* Fix gradient problem
* Fix req grad
* fix merge
* Fix merge
* Resolve merge
* arrangement and document
* Resolve merge error
* Add assert for mixed precision
---
 library/train_util.py |  3 +++
 train_network.py      | 36 ++++++++++++++++++++++++++++++------
 2 files changed, 33 insertions(+), 6 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index ff161feab..21e7638da 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
diff --git a/train_network.py b/train_network.py
index c2b7fbdef..5f28a5e0d 100644
--- a/train_network.py
+++ b/train_network.py
@@ -390,16 +390,36 @@ def train(self, args):
             accelerator.print("enable full bf16 training.")
             network.to(weight_dtype)
 
+        unet_weight_dtype = te_weight_dtype = weight_dtype
+        # Experimental Feature: Put base model into fp8 to save vram
+        if args.fp8_base:
+            assert (
+                torch.__version__ >= '2.1.0'
+            ), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
+            assert (
+                args.mixed_precision != 'no'
+            ), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
+            accelerator.print("enable fp8 training.")
+            unet_weight_dtype = torch.float8_e4m3fn
+            te_weight_dtype = torch.float8_e4m3fn
+
         unet.requires_grad_(False)
-        unet.to(dtype=weight_dtype)
+        unet.to(dtype=unet_weight_dtype)
         for t_enc in text_encoders:
             t_enc.requires_grad_(False)
+            t_enc.to(dtype=te_weight_dtype)
+            # nn.Embedding not support FP8
+            t_enc.text_model.embeddings.to(dtype=(
+                weight_dtype
+                if te_weight_dtype == torch.float8_e4m3fn
+                else te_weight_dtype
+            ))
 
         # acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
         if train_unet:
             unet = accelerator.prepare(unet)
         else:
-            unet.to(accelerator.device, dtype=weight_dtype)  # move to device because unet is not prepared by accelerator
+            unet.to(accelerator.device, dtype=unet_weight_dtype)  # move to device because unet is not prepared by accelerator
         if train_text_encoder:
             if len(text_encoders) > 1:
                 text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
@@ -421,9 +441,6 @@ def train(self, args):
                 if train_text_encoder:
                     t_enc.text_model.embeddings.requires_grad_(True)
 
-            # set top parameter requires_grad = True for gradient checkpointing works
-            if not train_text_encoder:  # train U-Net only
-                unet.parameters().__next__().requires_grad_(True)
         else:
             unet.eval()
             for t_enc in text_encoders:
@@ -778,10 +795,17 @@ def remove_model(old_ckpt_name):
                         args, noise_scheduler, latents
                     )
 
+                    # ensure the hidden state will require grad
+                    if args.gradient_checkpointing:
+                        for x in noisy_latents:
+                            x.requires_grad_(True)
+                        for t in text_encoder_conds:
+                            t.requires_grad_(True)
+
                     # Predict the noise residual
                     with accelerator.autocast():
                         noise_pred = self.call_unet(
-                            args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                            args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
                         )
 
                 if args.v_parameterization:
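
The hunks above boil down to one pattern: the frozen base model (U-Net and text encoders) is stored in torch.float8_e4m3fn to cut VRAM, while nn.Embedding is kept in the fp16/bf16 weight dtype because embedding lookup does not support fp8. Below is a minimal sketch of that storage pattern; it is not code from the repository, the layer sizes are arbitrary stand-ins, it assumes torch>=2.1.0, and the upcast before the matmul is written out by hand purely for illustration, since most PyTorch ops do not accept fp8 inputs:

    import torch
    import torch.nn.functional as F

    weight_dtype = torch.bfloat16        # compute dtype chosen via --mixed_precision
    fp8_dtype = torch.float8_e4m3fn      # storage dtype enabled via --fp8_base (torch>=2.1.0)

    # toy stand-ins for the frozen base-model modules (sizes are arbitrary)
    linear = torch.nn.Linear(320, 320)
    emb = torch.nn.Embedding(49408, 320)

    linear.requires_grad_(False)
    linear.to(dtype=fp8_dtype)           # frozen weights stored at 1 byte per element
    emb.to(dtype=weight_dtype)           # nn.Embedding stays in the compute dtype

    ids = torch.tensor([[1, 2, 3]])
    h = emb(ids)                         # bf16 activations
    # upcast the fp8 weights only for the matmul; activations stay in bf16
    out = F.linear(h, linear.weight.to(weight_dtype), linear.bias.to(weight_dtype))
    print(linear.weight.dtype, out.dtype)  # torch.float8_e4m3fn torch.bfloat16

In the patch itself no manual upcast is spelled out: the forward pass runs under accelerator.autocast and the new assert requires mixed_precision to be 'fp16' or 'bf16', so the actual computation is meant to happen in the mixed-precision dtype while the parameters stay stored in fp8.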