[Experimental Feature] FP8 weight dtype for base model when running train_network (or sdxl_train_network) (#1057)

* Add fp8 support

* remove some debug prints

* Better implementation for te

* Fix some misunderstanding

* Same as unet, add explicit convert

* better impl for converting TE to fp8

* fp8 for not only unet

* Better cache TE and TE lr

* match arg name

* Fix with list

* Add timeout settings

* Fix arg style

* Add custom separator

* Fix typo

* Fix typo again

* Fix dtype error

* Fix gradient problem

* Fix req grad

* fix merge

* Fix merge

* Resolve merge

* arrangement and document

* Resolve merge error

* Add assert for mixed precision
KohakuBlueleaf committed Jan 20, 2024
1 parent 0395a35 commit 9cfa68c
Showing 2 changed files with 33 additions and 6 deletions.
library/train_util.py (3 additions, 0 deletions)
@@ -2904,6 +2904,9 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--full_bf16", action="store_true", help="bf16 training including gradients / 勾配も含めてbf16で学習する"
     )  # TODO move to SDXL training, because it is not supported by SD1/2
+    parser.add_argument(
+        "--fp8_base", action="store_true", help="use fp8 for base model / base modelにfp8を使う"
+    )
     parser.add_argument(
         "--ddp_timeout",
         type=int,
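For reference, the new flag is meant to be combined with mixed precision; a hypothetical invocation (the model path and LoRA options are placeholders, and other required arguments are omitted) might look like:

accelerate launch train_network.py \
  --pretrained_model_name_or_path=/path/to/model.safetensors \
  --network_module=networks.lora \
  --mixed_precision=fp16 \
  --fp8_base

Per the assertions added below, --fp8_base requires torch>=2.1.0 and --mixed_precision set to fp16 or bf16.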
train_network.py (30 additions, 6 deletions)
@@ -390,16 +390,36 @@ def train(self, args):
accelerator.print("enable full bf16 training.")
network.to(weight_dtype)

unet_weight_dtype = te_weight_dtype = weight_dtype
# Experimental Feature: Put base model into fp8 to save vram
if args.fp8_base:
assert (
torch.__version__ >= '2.1.0'
), "fp8_base requires torch>=2.1.0 / fp8を使う場合はtorch>=2.1.0が必要です。"
assert (
args.mixed_precision != 'no'
), "fp8_base requires mixed precision='fp16' or 'bf16' / fp8を使う場合はmixed_precision='fp16'または'bf16'が必要です。"
accelerator.print("enable fp8 training.")
unet_weight_dtype = torch.float8_e4m3fn
te_weight_dtype = torch.float8_e4m3fn

unet.requires_grad_(False)
unet.to(dtype=weight_dtype)
unet.to(dtype=unet_weight_dtype)
for t_enc in text_encoders:
t_enc.requires_grad_(False)
t_enc.to(dtype=te_weight_dtype)
# nn.Embedding not support FP8
t_enc.text_model.embeddings.to(dtype=(
weight_dtype
if te_weight_dtype == torch.float8_e4m3fn
else te_weight_dtype
))

# acceleratorがなんかよろしくやってくれるらしい / accelerator will do something good
if train_unet:
unet = accelerator.prepare(unet)
else:
unet.to(accelerator.device, dtype=weight_dtype) # move to device because unet is not prepared by accelerator
unet.to(accelerator.device, dtype=unet_weight_dtype) # move to device because unet is not prepared by accelerator
if train_text_encoder:
if len(text_encoders) > 1:
text_encoder = text_encoders = [accelerator.prepare(t_enc) for t_enc in text_encoders]
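To illustrate the casting pattern in the hunk above, here is a minimal, self-contained sketch (not the repository's code; the module and its sizes are invented): frozen weights are stored in torch.float8_e4m3fn, nn.Embedding is kept in the mixed-precision dtype because embedding lookup does not support fp8, and fp8 weights are upcast to the compute dtype before the matmul. In train_network.py itself, the forward pass runs under accelerator.autocast() instead of this explicit upcast.

import torch
import torch.nn as nn

weight_dtype = torch.bfloat16          # mixed-precision compute dtype
te_weight_dtype = torch.float8_e4m3fn  # fp8 storage dtype (torch>=2.1.0)

class TinyTextEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(1000, 64)
        self.proj = nn.Linear(64, 64)

    def forward(self, ids):
        x = self.embeddings(ids)  # stays in bf16; fp8 lookup is unsupported
        # upcast the fp8 weights to the compute dtype before the matmul
        return nn.functional.linear(
            x, self.proj.weight.to(weight_dtype), self.proj.bias.to(weight_dtype)
        )

t_enc = TinyTextEncoder()
t_enc.requires_grad_(False)
t_enc.to(dtype=te_weight_dtype)          # most weights stored as fp8
t_enc.embeddings.to(dtype=weight_dtype)  # nn.Embedding kept in bf16

print(t_enc(torch.tensor([[1, 2, 3]])).dtype)  # torch.bfloat16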
@@ -421,9 +441,6 @@ def train(self, args):
                 if train_text_encoder:
                     t_enc.text_model.embeddings.requires_grad_(True)
 
-            # set top parameter requires_grad = True for gradient checkpointing works
-            if not train_text_encoder:  # train U-Net only
-                unet.parameters().__next__().requires_grad_(True)
         else:
             unet.eval()
             for t_enc in text_encoders:
@@ -778,10 +795,17 @@ def remove_model(old_ckpt_name):
                     args, noise_scheduler, latents
                 )
 
+                # ensure the hidden state will require grad
+                if args.gradient_checkpointing:
+                    for x in noisy_latents:
+                        x.requires_grad_(True)
+                    for t in text_encoder_conds:
+                        t.requires_grad_(True)
+
                 # Predict the noise residual
                 with accelerator.autocast():
                     noise_pred = self.call_unet(
-                        args, accelerator, unet, noisy_latents, timesteps, text_encoder_conds, batch, weight_dtype
+                        args, accelerator, unet, noisy_latents.requires_grad_(train_unet), timesteps, text_encoder_conds, batch, weight_dtype
                     )
 
                 if args.v_parameterization:
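The last hunk matters for gradient checkpointing: with the base model frozen (now possibly in fp8) and only the network trained, no parameter inside a checkpointed segment requires grad, so the inputs must carry requires_grad for autograd to reach the trainable layers. A minimal sketch of the fix (an assumed example, not the repository's code):

import torch
from torch.utils.checkpoint import checkpoint

frozen = torch.nn.Linear(4, 4)
frozen.requires_grad_(False)  # frozen base weights, as with --fp8_base

x = torch.randn(2, 4)
x.requires_grad_(True)        # the fix: let the input carry the graph
y = checkpoint(frozen, x, use_reentrant=True)
y.sum().backward()
print(x.grad is not None)     # True: gradients flow back to earlier layers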
