From 47f44ffa86de7f91b90ac9046677a7931803974f Mon Sep 17 00:00:00 2001 From: Z-Fran <1396925302@qq.com> Date: Sun, 25 Jun 2023 20:03:53 +0800 Subject: [PATCH 1/3] [Fix] fix sd and controlnet fp16 bugs --- .../models/editors/controlnet/controlnet.py | 32 +++++++++++++++++-- .../stable_diffusion/stable_diffusion.py | 19 +++++++++++ 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py index f752d58ecf..6c3f6c977e 100644 --- a/mmagic/models/editors/controlnet/controlnet.py +++ b/mmagic/models/editors/controlnet/controlnet.py @@ -80,8 +80,10 @@ def __init__(self, default_args = dict() if dtype is not None: default_args['dtype'] = dtype - self.controlnet = build_module( - controlnet, MODELS, default_args=default_args) + + # NOTE: initialize controlnet as fp32 + self.controlnet = build_module(controlnet, MODELS) + self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype self.set_xformers(self.controlnet) self.vae.requires_grad_(False) @@ -366,6 +368,25 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor, return image + def train(self, mode: bool = True): + """Set train/eval mode. + + Args: + mode (bool, optional): Whether set train mode. Defaults to True. + """ + if mode: + self.controlnet.to(self._controlnet_ori_dtype) + print_log( + 'Set ControlNetModel dtype to ' + f'\'{self._controlnet_ori_dtype}\' in the train mode.', + 'current') + else: + self.controlnet.to(self.dtype) + print_log( + f'Set ControlNetModel dtype to \'{self.dtype}\' ' + 'in the eval mode.', 'current') + return super().train(mode) + @torch.no_grad() def infer(self, prompt: Union[str, List[str]], @@ -448,6 +469,8 @@ def infer(self, # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \ + else self.vae.dtype if is_model_wrapper(self.controlnet): control_dtype = self.controlnet.module.dtype else: @@ -500,6 +523,9 @@ def infer(self, latent_model_input = self.test_scheduler.scale_model_input( latent_model_input, t) + latent_model_input = latent_model_input.to(control_dtype) + text_embeddings = text_embeddings.to(control_dtype) + down_block_res_samples, mid_block_res_sample = self.controlnet( latent_model_input, t, @@ -534,7 +560,7 @@ def infer(self, noise_pred, t, latents, **extra_step_kwargs)['prev_sample'] # 8. Post-processing - image = self.decode_latents(latents) + image = self.decode_latents(latents.to(img_dtype)) if do_classifier_free_guidance: controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0] diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py index 38cbdff289..673dd9e072 100644 --- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py +++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mmengine import print_log from mmengine.logging import MMLogger from mmengine.model import BaseModel from mmengine.runner import set_random_seed @@ -90,6 +91,7 @@ def __init__(self, self.vae = build_module(vae, MODELS, default_args=default_args) self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32 + self._unet_ori_dtype = next(self.unet.parameters()).dtype self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS) if test_scheduler is None: self.test_scheduler = deepcopy(self.scheduler) @@ -140,6 +142,23 @@ def set_tomesd(self) -> nn.Module: def device(self): return next(self.parameters()).device + def train(self, mode: bool = True): + """Set train/eval mode. + + Args: + mode (bool, optional): Whether set train mode. Defaults to True. + """ + if mode: + self.unet.to(self._unet_ori_dtype) + print_log( + f'Set UNet dtype to \'{self._unet_ori_dtype}\' ' + 'in the train mode.', 'current') + else: + self.unet.to(self.dtype) + print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.', + 'current') + return super().train(mode) + @torch.no_grad() def infer(self, prompt: Union[str, List[str]], From 74275c9de8249d4580f451f7bcc8d5e143cbd656 Mon Sep 17 00:00:00 2001 From: Z-Fran <1396925302@qq.com> Date: Mon, 26 Jun 2023 16:10:51 +0800 Subject: [PATCH 2/3] fix --- .../models/editors/controlnet/controlnet.py | 56 +++++++++---------- .../stable_diffusion/stable_diffusion.py | 29 ++++------ 2 files changed, 38 insertions(+), 47 deletions(-) diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py index 6c3f6c977e..7ca54900e1 100644 --- a/mmagic/models/editors/controlnet/controlnet.py +++ b/mmagic/models/editors/controlnet/controlnet.py @@ -83,7 +83,6 @@ def __init__(self, # NOTE: initialize controlnet as fp32 self.controlnet = build_module(controlnet, MODELS) - self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype self.set_xformers(self.controlnet) self.vae.requires_grad_(False) @@ -203,6 +202,7 @@ def train_step(self, data: dict, num_batches = target.shape[0] + target = target.to(self.dtype) latents = self.vae.encode(target).latent_dist.sample() latents = latents * self.vae.config.scaling_factor @@ -233,19 +233,20 @@ def train_step(self, data: dict, f'{self.scheduler.config.prediction_type}') # forward control + # NOTE: we train controlnet in fp32, convert to float manually down_block_res_samples, mid_block_res_sample = self.controlnet( - noisy_latents, + noisy_latents.float(), timesteps, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=control, + encoder_hidden_states=encoder_hidden_states.float(), + controlnet_cond=control.float(), return_dict=False, ) - # Predict the noise residual and compute loss + # NOTE: we train unet in fp32, convert to float manually model_output = self.unet( - noisy_latents, + noisy_latents.float(), timesteps, - encoder_hidden_states=encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states.float(), down_block_additional_residuals=down_block_res_samples, mid_block_additional_residual=mid_block_res_sample) model_pred = model_output['sample'] @@ -270,10 +271,19 @@ def val_step(self, data: dict) -> SampleList: data = self.data_preprocessor(data) prompt = data['data_samples'].prompt control = data['inputs']['source'] + + unet_dtype = next(self.unet.parameters()).dtype + self.unet.to(self.dtype) + controlnet_dtype = next(self.controlnet.parameters()).dtype + self.controlnet.to(self.dtype) + output = self.infer( prompt, control=((control + 1) / 2), return_type='tensor') - samples = output['samples'] + + self.unet.to(unet_dtype) + self.controlnet.to(controlnet_dtype) + samples = self.data_preprocessor.destruct( samples, data['data_samples'], key='target') control = self.data_preprocessor.destruct( @@ -300,10 +310,19 @@ def test_step(self, data: dict) -> SampleList: data = self.data_preprocessor(data) prompt = data['data_samples'].prompt control = data['inputs']['source'] + + unet_dtype = next(self.unet.parameters()).dtype + self.unet.to(self.dtype) + controlnet_dtype = next(self.controlnet.parameters()).dtype + self.controlnet.to(self.dtype) + output = self.infer( prompt, control=((control + 1) / 2), return_type='tensor') - samples = output['samples'] + + self.unet.to(unet_dtype) + self.controlnet.to(controlnet_dtype) + samples = self.data_preprocessor.destruct( samples, data['data_samples'], key='target') control = self.data_preprocessor.destruct( @@ -368,25 +387,6 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor, return image - def train(self, mode: bool = True): - """Set train/eval mode. - - Args: - mode (bool, optional): Whether set train mode. Defaults to True. - """ - if mode: - self.controlnet.to(self._controlnet_ori_dtype) - print_log( - 'Set ControlNetModel dtype to ' - f'\'{self._controlnet_ori_dtype}\' in the train mode.', - 'current') - else: - self.controlnet.to(self.dtype) - print_log( - f'Set ControlNetModel dtype to \'{self.dtype}\' ' - 'in the eval mode.', 'current') - return super().train(mode) - @torch.no_grad() def infer(self, prompt: Union[str, List[str]], diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py index 673dd9e072..0589fb03a4 100644 --- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py +++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -from mmengine import print_log from mmengine.logging import MMLogger from mmengine.model import BaseModel from mmengine.runner import set_random_seed @@ -91,7 +90,6 @@ def __init__(self, self.vae = build_module(vae, MODELS, default_args=default_args) self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32 - self._unet_ori_dtype = next(self.unet.parameters()).dtype self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS) if test_scheduler is None: self.test_scheduler = deepcopy(self.scheduler) @@ -142,23 +140,6 @@ def set_tomesd(self) -> nn.Module: def device(self): return next(self.parameters()).device - def train(self, mode: bool = True): - """Set train/eval mode. - - Args: - mode (bool, optional): Whether set train mode. Defaults to True. - """ - if mode: - self.unet.to(self._unet_ori_dtype) - print_log( - f'Set UNet dtype to \'{self._unet_ori_dtype}\' ' - 'in the train mode.', 'current') - else: - self.unet.to(self.dtype) - print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.', - 'current') - return super().train(mode) - @torch.no_grad() def infer(self, prompt: Union[str, List[str]], @@ -591,9 +572,14 @@ def val_step(self, data: dict) -> SampleList: data_samples = data['data_samples'] prompt = data_samples.prompt + unet_dtype = next(self.unet.parameters()).dtype + self.unet.to(self.dtype) + output = self.infer(prompt, return_type='tensor') samples = output['samples'] + self.unet.to(unet_dtype) + samples = self.data_preprocessor.destruct(samples, data_samples) gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples) @@ -610,9 +596,14 @@ def test_step(self, data: dict) -> SampleList: data_samples = data['data_samples'] prompt = data_samples.prompt + unet_dtype = next(self.unet.parameters()).dtype + self.unet.to(self.dtype) + output = self.infer(prompt, return_type='tensor') samples = output['samples'] + self.unet.to(unet_dtype) + samples = self.data_preprocessor.destruct(samples, data_samples) gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples) From 720c1767e6c8b7e2357abb2934423a5f5d20e11c Mon Sep 17 00:00:00 2001 From: Z-Fran <1396925302@qq.com> Date: Mon, 26 Jun 2023 18:29:28 +0800 Subject: [PATCH 3/3] fix --- .../controlnet_animation_inferencer.py | 2 +- .../models/editors/controlnet/controlnet.py | 50 ++++++++++++------- .../models/editors/dreambooth/dreambooth.py | 12 ----- .../stable_diffusion/stable_diffusion.py | 31 ++++++++---- 4 files changed, 55 insertions(+), 40 deletions(-) diff --git a/mmagic/apis/inferencers/controlnet_animation_inferencer.py b/mmagic/apis/inferencers/controlnet_animation_inferencer.py index 56bb98fb67..3c811b6aae 100644 --- a/mmagic/apis/inferencers/controlnet_animation_inferencer.py +++ b/mmagic/apis/inferencers/controlnet_animation_inferencer.py @@ -80,7 +80,7 @@ def __init__(self, self.inference_method = cfg.inference_method if self.inference_method == 'attention_injection': cfg.model.attention_injection = True - self.pipe = MODELS.build(cfg.model).cuda() + self.pipe = MODELS.build(cfg.model).cuda().eval() control_scheduler_cfg = dict( type=cfg.control_scheduler, diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py index 7ca54900e1..dfce2c44ed 100644 --- a/mmagic/models/editors/controlnet/controlnet.py +++ b/mmagic/models/editors/controlnet/controlnet.py @@ -83,6 +83,10 @@ def __init__(self, # NOTE: initialize controlnet as fp32 self.controlnet = build_module(controlnet, MODELS) + self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype + print_log( + 'Set ControlNetModel dtype to ' + f'\'{self._controlnet_ori_dtype}\'.', 'current') self.set_xformers(self.controlnet) self.vae.requires_grad_(False) @@ -272,18 +276,10 @@ def val_step(self, data: dict) -> SampleList: prompt = data['data_samples'].prompt control = data['inputs']['source'] - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - controlnet_dtype = next(self.controlnet.parameters()).dtype - self.controlnet.to(self.dtype) - output = self.infer( prompt, control=((control + 1) / 2), return_type='tensor') samples = output['samples'] - self.unet.to(unet_dtype) - self.controlnet.to(controlnet_dtype) - samples = self.data_preprocessor.destruct( samples, data['data_samples'], key='target') control = self.data_preprocessor.destruct( @@ -311,18 +307,10 @@ def test_step(self, data: dict) -> SampleList: prompt = data['data_samples'].prompt control = data['inputs']['source'] - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - controlnet_dtype = next(self.controlnet.parameters()).dtype - self.controlnet.to(self.dtype) - output = self.infer( prompt, control=((control + 1) / 2), return_type='tensor') samples = output['samples'] - self.unet.to(unet_dtype) - self.controlnet.to(controlnet_dtype) - samples = self.data_preprocessor.destruct( samples, data['data_samples'], key='target') control = self.data_preprocessor.destruct( @@ -387,6 +375,27 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor, return image + def train(self, mode: bool = True): + """Set train/eval mode. + + Args: + mode (bool, optional): Whether set train mode. Defaults to True. + """ + if mode: + if next(self.controlnet.parameters() + ).dtype != self._controlnet_ori_dtype: + print_log( + 'Set ControlNetModel dtype to ' + f'\'{self._controlnet_ori_dtype}\' in the train mode.', + 'current') + self.controlnet.to(self._controlnet_ori_dtype) + else: + self.controlnet.to(self.dtype) + print_log( + f'Set ControlNetModel dtype to \'{self.dtype}\' ' + 'in the eval mode.', 'current') + return super().train(mode) + @torch.no_grad() def infer(self, prompt: Union[str, List[str]], @@ -791,6 +800,8 @@ def infer( # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 + img_dtype = self.vae.module.dtype if hasattr( + self.vae, 'module') else self.vae.dtype if is_model_wrapper(self.controlnet): control_dtype = self.controlnet.module.dtype else: @@ -818,6 +829,7 @@ def infer( num_images_per_prompt, do_classifier_free_guidance, negative_prompt) + text_embeddings = text_embeddings.to(control_dtype) # 4. Prepare timesteps # self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -866,6 +878,8 @@ def infer( latent_model_input = self.test_scheduler.scale_model_input( latent_model_input, t) + latent_model_input = latent_model_input.to(control_dtype) + if reference_img is not None: ref_img_vae_latents_t = self.scheduler.add_noise( ref_img_vae_latents, torch.randn_like(ref_img_vae_latents), @@ -876,6 +890,8 @@ def infer( ref_img_vae_latents_model_input = \ self.test_scheduler.scale_model_input( ref_img_vae_latents_model_input, t) + ref_img_vae_latents_model_input = \ + ref_img_vae_latents_model_input.to(control_dtype) down_block_res_samples, mid_block_res_sample = self.controlnet( latent_model_input, @@ -924,7 +940,7 @@ def infer( vae_encode_latents * (1.0 - latent_mask) # 8. Post-processing - image = self.decode_latents(latents) + image = self.decode_latents(latents.to(img_dtype)) if do_classifier_free_guidance: controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0] diff --git a/mmagic/models/editors/dreambooth/dreambooth.py b/mmagic/models/editors/dreambooth/dreambooth.py index 0ad2af3a8f..a711f16d8b 100644 --- a/mmagic/models/editors/dreambooth/dreambooth.py +++ b/mmagic/models/editors/dreambooth/dreambooth.py @@ -191,14 +191,8 @@ def val_step(self, data: dict) -> SampleList: data_samples.split() * len(prompt) data_samples = DataSample.stack(data_samples.split() * len(prompt)) - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - output = self.infer(prompt, return_type='tensor') samples = output['samples'] - - self.unet.to(unet_dtype) - samples = self.data_preprocessor.destruct(samples, data_samples) out_data_sample = DataSample(fake_img=samples, prompt=prompt) @@ -226,14 +220,8 @@ def test_step(self, data: dict) -> SampleList: # construct a fake data_sample for destruct data_samples = DataSample.stack(data['data_samples'] * len(prompt)) - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - output = self.infer(prompt, return_type='tensor') samples = output['samples'] - - self.unet.to(unet_dtype) - samples = self.data_preprocessor.destruct(samples, data_samples) out_data_sample = DataSample(fake_img=samples, prompt=prompt) diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py index 0589fb03a4..02cd956fcc 100644 --- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py +++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +from mmengine import print_log from mmengine.logging import MMLogger from mmengine.model import BaseModel from mmengine.runner import set_random_seed @@ -90,6 +91,8 @@ def __init__(self, self.vae = build_module(vae, MODELS, default_args=default_args) self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32 + self._unet_ori_dtype = next(self.unet.parameters()).dtype + print_log(f'Set UNet dtype to \'{self._unet_ori_dtype}\'.', 'current') self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS) if test_scheduler is None: self.test_scheduler = deepcopy(self.scheduler) @@ -140,6 +143,24 @@ def set_tomesd(self) -> nn.Module: def device(self): return next(self.parameters()).device + def train(self, mode: bool = True): + """Set train/eval mode. + + Args: + mode (bool, optional): Whether set train mode. Defaults to True. + """ + if mode: + if next(self.unet.parameters()).dtype != self._unet_ori_dtype: + print_log( + f'Set UNet dtype to \'{self._unet_ori_dtype}\' ' + 'in the train mode.', 'current') + self.unet.to(self._unet_ori_dtype) + else: + self.unet.to(self.dtype) + print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.', + 'current') + return super().train(mode) + @torch.no_grad() def infer(self, prompt: Union[str, List[str]], @@ -572,14 +593,9 @@ def val_step(self, data: dict) -> SampleList: data_samples = data['data_samples'] prompt = data_samples.prompt - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - output = self.infer(prompt, return_type='tensor') samples = output['samples'] - self.unet.to(unet_dtype) - samples = self.data_preprocessor.destruct(samples, data_samples) gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples) @@ -596,14 +612,9 @@ def test_step(self, data: dict) -> SampleList: data_samples = data['data_samples'] prompt = data_samples.prompt - unet_dtype = next(self.unet.parameters()).dtype - self.unet.to(self.dtype) - output = self.infer(prompt, return_type='tensor') samples = output['samples'] - self.unet.to(unet_dtype) - samples = self.data_preprocessor.destruct(samples, data_samples) gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples)