diff --git a/mmagic/apis/inferencers/controlnet_animation_inferencer.py b/mmagic/apis/inferencers/controlnet_animation_inferencer.py
index 56bb98fb67..3c811b6aae 100644
--- a/mmagic/apis/inferencers/controlnet_animation_inferencer.py
+++ b/mmagic/apis/inferencers/controlnet_animation_inferencer.py
@@ -80,7 +80,7 @@ def __init__(self,
         self.inference_method = cfg.inference_method
         if self.inference_method == 'attention_injection':
             cfg.model.attention_injection = True
-        self.pipe = MODELS.build(cfg.model).cuda()
+        self.pipe = MODELS.build(cfg.model).cuda().eval()
 
         control_scheduler_cfg = dict(
             type=cfg.control_scheduler,
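Note: the added `.eval()` pairs with the `train()` overrides introduced below in `controlnet.py` and `stable_diffusion.py`. A freshly built module is in train mode, where the ControlNet/UNet are deliberately kept in their original fp32 dtype; switching to eval casts them to the configured inference dtype. A minimal sketch of the intended call sequence (`cfg.model` here is illustrative, not a specific config):

```python
# Sketch only: MODELS.build returns the pipeline in train mode, so the
# ControlNet/UNet weights are still fp32 at this point.
pipe = MODELS.build(cfg.model).cuda()
# eval() triggers the train(mode=False) hook below, casting the
# ControlNet/UNet to the configured inference dtype (e.g. torch.float16).
pipe = pipe.eval()
```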
diff --git a/mmagic/models/editors/controlnet/controlnet.py b/mmagic/models/editors/controlnet/controlnet.py
index f752d58ecf..dfce2c44ed 100644
--- a/mmagic/models/editors/controlnet/controlnet.py
+++ b/mmagic/models/editors/controlnet/controlnet.py
@@ -80,8 +80,13 @@ def __init__(self,
         default_args = dict()
         if dtype is not None:
             default_args['dtype'] = dtype
-        self.controlnet = build_module(
-            controlnet, MODELS, default_args=default_args)
+
+        # NOTE: initialize controlnet as fp32
+        self.controlnet = build_module(controlnet, MODELS)
+        self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype
+        print_log(
+            'Set ControlNetModel dtype to '
+            f'\'{self._controlnet_ori_dtype}\'.', 'current')
         self.set_xformers(self.controlnet)
 
         self.vae.requires_grad_(False)
@@ -201,6 +206,7 @@ def train_step(self, data: dict,
 
             num_batches = target.shape[0]
 
+            target = target.to(self.dtype)
             latents = self.vae.encode(target).latent_dist.sample()
             latents = latents * self.vae.config.scaling_factor
 
@@ -231,19 +237,20 @@ def train_step(self, data: dict,
                     f'{self.scheduler.config.prediction_type}')
 
             # forward control
+            # NOTE: we train controlnet in fp32, convert to float manually
             down_block_res_samples, mid_block_res_sample = self.controlnet(
-                noisy_latents,
+                noisy_latents.float(),
                 timesteps,
-                encoder_hidden_states=encoder_hidden_states,
-                controlnet_cond=control,
+                encoder_hidden_states=encoder_hidden_states.float(),
+                controlnet_cond=control.float(),
                 return_dict=False,
             )
 
-            # Predict the noise residual and compute loss
+            # NOTE: we train unet in fp32, convert to float manually
             model_output = self.unet(
-                noisy_latents,
+                noisy_latents.float(),
                 timesteps,
-                encoder_hidden_states=encoder_hidden_states,
+                encoder_hidden_states=encoder_hidden_states.float(),
                 down_block_additional_residuals=down_block_res_samples,
                 mid_block_additional_residual=mid_block_res_sample)
             model_pred = model_output['sample']
@@ -268,10 +275,11 @@ def val_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
-
         samples = output['samples']
+
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -298,10 +306,11 @@ def test_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
-
         samples = output['samples']
+
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -366,6 +375,27 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
 
         return image
 
+    def train(self, mode: bool = True):
+        """Set train/eval mode.
+
+        Args:
+            mode (bool, optional): Whether set train mode. Defaults to True.
+        """
+        if mode:
+            if next(self.controlnet.parameters()
+                    ).dtype != self._controlnet_ori_dtype:
+                print_log(
+                    'Set ControlNetModel dtype to '
+                    f'\'{self._controlnet_ori_dtype}\' in the train mode.',
+                    'current')
+            self.controlnet.to(self._controlnet_ori_dtype)
+        else:
+            self.controlnet.to(self.dtype)
+            print_log(
+                f'Set ControlNetModel dtype to \'{self.dtype}\' '
+                'in the eval mode.', 'current')
+        return super().train(mode)
+
     @torch.no_grad()
     def infer(self,
               prompt: Union[str, List[str]],
@@ -448,6 +478,8 @@ def infer(self,
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
+        img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
+            else self.vae.dtype
         if is_model_wrapper(self.controlnet):
             control_dtype = self.controlnet.module.dtype
         else:
@@ -500,6 +532,9 @@ def infer(self,
             latent_model_input = self.test_scheduler.scale_model_input(
                 latent_model_input, t)
 
+            latent_model_input = latent_model_input.to(control_dtype)
+            text_embeddings = text_embeddings.to(control_dtype)
+
             down_block_res_samples, mid_block_res_sample = self.controlnet(
                 latent_model_input,
                 t,
@@ -534,7 +569,7 @@ def infer(self,
                 noise_pred, t, latents, **extra_step_kwargs)['prev_sample']
 
         # 8. Post-processing
-        image = self.decode_latents(latents)
+        image = self.decode_latents(latents.to(img_dtype))
 
         if do_classifier_free_guidance:
             controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]
@@ -765,6 +800,8 @@ def infer(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0
 
+        img_dtype = self.vae.module.dtype if hasattr(
+            self.vae, 'module') else self.vae.dtype
         if is_model_wrapper(self.controlnet):
             control_dtype = self.controlnet.module.dtype
         else:
@@ -792,6 +829,7 @@ def infer(
             num_images_per_prompt,
             do_classifier_free_guidance,
             negative_prompt)
+        text_embeddings = text_embeddings.to(control_dtype)
 
         # 4. Prepare timesteps
         # self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -840,6 +878,8 @@ def infer(
             latent_model_input = self.test_scheduler.scale_model_input(
                 latent_model_input, t)
 
+            latent_model_input = latent_model_input.to(control_dtype)
+
             if reference_img is not None:
                 ref_img_vae_latents_t = self.scheduler.add_noise(
                     ref_img_vae_latents, torch.randn_like(ref_img_vae_latents),
@@ -850,6 +890,8 @@ def infer(
                 ref_img_vae_latents_model_input = \
                     self.test_scheduler.scale_model_input(
                         ref_img_vae_latents_model_input, t)
+                ref_img_vae_latents_model_input = \
+                    ref_img_vae_latents_model_input.to(control_dtype)
 
             down_block_res_samples, mid_block_res_sample = self.controlnet(
                 latent_model_input,
@@ -898,7 +940,7 @@ def infer(
                     vae_encode_latents * (1.0 - latent_mask)
 
         # 8. Post-processing
-        image = self.decode_latents(latents)
+        image = self.decode_latents(latents.to(img_dtype))
 
         if do_classifier_free_guidance:
             controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]
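The heart of this patch is the `train()` override above: weights stay in fp32 while optimizing (inputs are cast with `.float()` in `train_step` to match) and are moved to the configured low-precision dtype only for inference. A self-contained sketch of the same pattern, using illustrative names (`DtypeToggle`, `submodule`, `inference_dtype`) that are not mmagic API:

```python
import torch
import torch.nn as nn
from mmengine import print_log


class DtypeToggle(nn.Module):
    """Keep a submodule in fp32 for training, low precision for eval."""

    def __init__(self, inference_dtype: torch.dtype = torch.float16):
        super().__init__()
        self.submodule = nn.Linear(4, 4)  # created in fp32 by default
        self._ori_dtype = next(self.submodule.parameters()).dtype
        self.inference_dtype = inference_dtype

    def train(self, mode: bool = True):
        # full precision for optimization, half precision for inference
        target = self._ori_dtype if mode else self.inference_dtype
        if next(self.submodule.parameters()).dtype != target:
            print_log(f'Set submodule dtype to \'{target}\'.', 'current')
            self.submodule.to(target)
        return super().train(mode)


m = DtypeToggle()
m.eval()   # submodule cast to fp16
m.train()  # submodule restored to fp32
```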
diff --git a/mmagic/models/editors/dreambooth/dreambooth.py b/mmagic/models/editors/dreambooth/dreambooth.py
index 0ad2af3a8f..a711f16d8b 100644
--- a/mmagic/models/editors/dreambooth/dreambooth.py
+++ b/mmagic/models/editors/dreambooth/dreambooth.py
@@ -191,14 +191,8 @@ def val_step(self, data: dict) -> SampleList:
             data_samples.split() * len(prompt)
         data_samples = DataSample.stack(data_samples.split() * len(prompt))
 
-        unet_dtype = next(self.unet.parameters()).dtype
-        self.unet.to(self.dtype)
-
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']
-
-        self.unet.to(unet_dtype)
-
         samples = self.data_preprocessor.destruct(samples, data_samples)
 
         out_data_sample = DataSample(fake_img=samples, prompt=prompt)
@@ -226,14 +220,8 @@ def test_step(self, data: dict) -> SampleList:
         # construct a fake data_sample for destruct
         data_samples = DataSample.stack(data['data_samples'] * len(prompt))
 
-        unet_dtype = next(self.unet.parameters()).dtype
-        self.unet.to(self.dtype)
-
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']
-
-        self.unet.to(unet_dtype)
-
         samples = self.data_preprocessor.destruct(samples, data_samples)
 
         out_data_sample = DataSample(fake_img=samples, prompt=prompt)
diff --git a/mmagic/models/editors/stable_diffusion/stable_diffusion.py b/mmagic/models/editors/stable_diffusion/stable_diffusion.py
index 38cbdff289..02cd956fcc 100644
--- a/mmagic/models/editors/stable_diffusion/stable_diffusion.py
+++ b/mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine import print_log
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel
 from mmengine.runner import set_random_seed
@@ -90,6 +91,8 @@ def __init__(self,
         self.vae = build_module(vae, MODELS, default_args=default_args)
         self.unet = build_module(unet, MODELS)  # NOTE: initialize unet as fp32
+        self._unet_ori_dtype = next(self.unet.parameters()).dtype
+        print_log(f'Set UNet dtype to \'{self._unet_ori_dtype}\'.', 'current')
         self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
         if test_scheduler is None:
             self.test_scheduler = deepcopy(self.scheduler)
@@ -140,6 +143,24 @@ def set_tomesd(self) -> nn.Module:
     def device(self):
         return next(self.parameters()).device
 
+    def train(self, mode: bool = True):
+        """Set train/eval mode.
+
+        Args:
+            mode (bool, optional): Whether set train mode. Defaults to True.
+        """
+        if mode:
+            if next(self.unet.parameters()).dtype != self._unet_ori_dtype:
+                print_log(
+                    f'Set UNet dtype to \'{self._unet_ori_dtype}\' '
+                    'in the train mode.', 'current')
+            self.unet.to(self._unet_ori_dtype)
+        else:
+            self.unet.to(self.dtype)
+            print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.',
+                      'current')
+        return super().train(mode)
+
     @torch.no_grad()
     def infer(self,
               prompt: Union[str, List[str]],
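One recurring pattern worth noting: because the VAE and ControlNet may each sit behind a wrapper (`is_model_wrapper`, DDP-style `.module`), the patch looks up `img_dtype` / `control_dtype` via `x.module.dtype if hasattr(x, 'module') else x.dtype` and casts inference tensors (`latents`, `latent_model_input`, `text_embeddings`) to match before each forward. A hedged helper sketch of that lookup (`unwrapped_dtype` is illustrative, not part of this PR):

```python
import torch
import torch.nn as nn


def unwrapped_dtype(module: nn.Module) -> torch.dtype:
    # A wrapped module (e.g. DistributedDataParallel) exposes the real
    # model under `.module`; resolve the parameter dtype through it.
    inner = module.module if hasattr(module, 'module') else module
    return next(inner.parameters()).dtype


# e.g. latents would be cast with latents.to(unwrapped_dtype(vae))
# before decode_latents, mirroring the img_dtype handling above.
```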