From fc8afa3ab5eb840ab0da5aadb629bf671eef9a39 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 24 Jan 2023 13:23:56 +0100 Subject: [PATCH] [dreambooth] fix multi on gpu. (#2088) unwrap model on multi gpu --- examples/dreambooth/train_dreambooth.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index b344cc98fd..3df55937c4 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -716,11 +716,16 @@ def main(args): " doing mixed precision training. copy of the weights should still be float32." ) - if unet.dtype != torch.float32: - raise ValueError(f"Unet loaded as datatype {unet.dtype}. {low_precision_error_string}") + if accelerator.unwrap_model(unet).dtype != torch.float32: + raise ValueError( + f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}" + ) - if args.train_text_encoder and text_encoder.dtype != torch.float32: - raise ValueError(f"Text encoder loaded as datatype {text_encoder.dtype}. {low_precision_error_string}") + if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32: + raise ValueError( + f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}." + f" {low_precision_error_string}" + ) # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)