[Fix] fix sd and controlnet fp16 bugs #1914

Merged · 3 commits · Jun 27, 2023

2 changes: 1 addition & 1 deletion mmagic/apis/inferencers/controlnet_animation_inferencer.py
@@ -80,7 +80,7 @@ def __init__(self,
         self.inference_method = cfg.inference_method
         if self.inference_method == 'attention_injection':
             cfg.model.attention_injection = True
-        self.pipe = MODELS.build(cfg.model).cuda()
+        self.pipe = MODELS.build(cfg.model).cuda().eval()

         control_scheduler_cfg = dict(
             type=cfg.control_scheduler,
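
The added `.eval()` matters because `nn.Module.eval()` simply calls `train(False)`, so the dtype handling this PR moves into the overridden `train()` methods (see controlnet.py and stable_diffusion.py below) already runs when the pipeline is built. A minimal plain-PyTorch sketch of that interaction (illustrative code, not mmagic's):

```python
import torch
import torch.nn as nn


class HalfOnEval(nn.Module):
    """Toy module mirroring the PR's pattern: fp32 in train mode, fp16 in eval."""

    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(4, 4)  # master weights stay fp32

    def train(self, mode: bool = True):
        # eval() calls train(False), so this cast also runs at .eval() time
        self.unet.to(torch.float32 if mode else torch.float16)
        return super().train(mode)


pipe = HalfOnEval().eval()  # analogous to MODELS.build(cfg.model).cuda().eval()
print(next(pipe.unet.parameters()).dtype)  # torch.float16
```
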
66 changes: 54 additions & 12 deletions mmagic/models/editors/controlnet/controlnet.py
@@ -80,8 +80,13 @@ def __init__(self,
         default_args = dict()
         if dtype is not None:
             default_args['dtype'] = dtype
-        self.controlnet = build_module(
-            controlnet, MODELS, default_args=default_args)
+
+        # NOTE: initialize controlnet as fp32
+        self.controlnet = build_module(controlnet, MODELS)
+        self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype
+        print_log(
+            'Set ControlNetModel dtype to '
+            f'\'{self._controlnet_ori_dtype}\'.', 'current')
         self.set_xformers(self.controlnet)

         self.vae.requires_grad_(False)
@@ -201,6 +206,7 @@ def train_step(self, data: dict,

         num_batches = target.shape[0]

+        target = target.to(self.dtype)
         latents = self.vae.encode(target).latent_dist.sample()
         latents = latents * self.vae.config.scaling_factor

@@ -231,19 +237,20 @@ def train_step(self, data: dict,
                 f'{self.scheduler.config.prediction_type}')

         # forward control
+        # NOTE: we train controlnet in fp32, convert to float manually
         down_block_res_samples, mid_block_res_sample = self.controlnet(
-            noisy_latents,
+            noisy_latents.float(),
             timesteps,
-            encoder_hidden_states=encoder_hidden_states,
-            controlnet_cond=control,
+            encoder_hidden_states=encoder_hidden_states.float(),
+            controlnet_cond=control.float(),
             return_dict=False,
         )

         # Predict the noise residual and compute loss
+        # NOTE: we train unet in fp32, convert to float manually
         model_output = self.unet(
-            noisy_latents,
+            noisy_latents.float(),
             timesteps,
-            encoder_hidden_states=encoder_hidden_states,
+            encoder_hidden_states=encoder_hidden_states.float(),
             down_block_additional_residuals=down_block_res_samples,
             mid_block_additional_residual=mid_block_res_sample)
         model_pred = model_output['sample']
@@ -268,10 +275,11 @@ def val_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
+
         samples = output['samples']
-
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -298,10 +306,11 @@ def test_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
+
         samples = output['samples']
-
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -366,6 +375,27 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,

         return image

+    def train(self, mode: bool = True):
+        """Set train/eval mode.
+
+        Args:
+            mode (bool, optional): Whether set train mode. Defaults to True.
+        """
+        if mode:
+            if next(self.controlnet.parameters()
+                    ).dtype != self._controlnet_ori_dtype:
+                print_log(
+                    'Set ControlNetModel dtype to '
+                    f'\'{self._controlnet_ori_dtype}\' in the train mode.',
+                    'current')
+                self.controlnet.to(self._controlnet_ori_dtype)
+        else:
+            self.controlnet.to(self.dtype)
+            print_log(
+                f'Set ControlNetModel dtype to \'{self.dtype}\' '
+                'in the eval mode.', 'current')
+        return super().train(mode)
+
     @torch.no_grad()
     def infer(self,
               prompt: Union[str, List[str]],
@@ -448,6 +478,8 @@ def infer(self,
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        img_dtype = self.vae.module.dtype if hasattr(self.vae, 'module') \
+            else self.vae.dtype
         if is_model_wrapper(self.controlnet):
             control_dtype = self.controlnet.module.dtype
         else:
@@ -500,6 +532,9 @@ def infer(self,
             latent_model_input = self.test_scheduler.scale_model_input(
                 latent_model_input, t)

+            latent_model_input = latent_model_input.to(control_dtype)
+            text_embeddings = text_embeddings.to(control_dtype)
+
             down_block_res_samples, mid_block_res_sample = self.controlnet(
                 latent_model_input,
                 t,
@@ -534,7 +569,7 @@ def infer(self,
                 noise_pred, t, latents, **extra_step_kwargs)['prev_sample']

         # 8. Post-processing
-        image = self.decode_latents(latents)
+        image = self.decode_latents(latents.to(img_dtype))

         if do_classifier_free_guidance:
             controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]
@@ -765,6 +800,8 @@ def infer(
         # corresponds to doing no classifier free guidance.
         do_classifier_free_guidance = guidance_scale > 1.0

+        img_dtype = self.vae.module.dtype if hasattr(
+            self.vae, 'module') else self.vae.dtype
         if is_model_wrapper(self.controlnet):
             control_dtype = self.controlnet.module.dtype
         else:
@@ -792,6 +829,7 @@ def infer(
             num_images_per_prompt,
             do_classifier_free_guidance,
             negative_prompt)
+        text_embeddings = text_embeddings.to(control_dtype)

         # 4. Prepare timesteps
         # self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -840,6 +878,8 @@ def infer(
             latent_model_input = self.test_scheduler.scale_model_input(
                 latent_model_input, t)

+            latent_model_input = latent_model_input.to(control_dtype)
+
             if reference_img is not None:
                 ref_img_vae_latents_t = self.scheduler.add_noise(
                     ref_img_vae_latents, torch.randn_like(ref_img_vae_latents),
@@ -850,6 +890,8 @@
                 ref_img_vae_latents_model_input = \
                     self.test_scheduler.scale_model_input(
                         ref_img_vae_latents_model_input, t)
+                ref_img_vae_latents_model_input = \
+                    ref_img_vae_latents_model_input.to(control_dtype)

             down_block_res_samples, mid_block_res_sample = self.controlnet(
                 latent_model_input,
@@ -898,7 +940,7 @@ def infer(
                     vae_encode_latents * (1.0 - latent_mask)

         # 8. Post-processing
-        image = self.decode_latents(latents)
+        image = self.decode_latents(latents.to(img_dtype))

         if do_classifier_free_guidance:
             controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]
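
The `.float()` casts added to `train_step` above exist because the ControlNet and UNet weights are now kept in fp32 for training, while the rest of the pipeline (for example `target.to(self.dtype)` before the VAE encode) may run in fp16, and PyTorch does not silently promote mixed dtypes in matmul-style ops. A small standalone sketch of the failure mode these casts avoid (illustrative code, not mmagic's):

```python
import torch
import torch.nn as nn

net = nn.Linear(8, 8)  # stand-in for fp32 ControlNet/UNet weights
x_fp16 = torch.randn(2, 8, dtype=torch.float16)  # fp16 batch from the data pipeline

try:
    net(x_fp16)  # fp16 input against fp32 weights
except RuntimeError as err:
    print('dtype mismatch:', err)

out = net(x_fp16.float())  # explicit cast, as train_step now does
print(out.dtype)  # torch.float32
```
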
12 changes: 0 additions & 12 deletions mmagic/models/editors/dreambooth/dreambooth.py
@@ -191,14 +191,8 @@ def val_step(self, data: dict) -> SampleList:
         data_samples.split() * len(prompt)
         data_samples = DataSample.stack(data_samples.split() * len(prompt))

-        unet_dtype = next(self.unet.parameters()).dtype
-        self.unet.to(self.dtype)
-
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']

-        self.unet.to(unet_dtype)
-
         samples = self.data_preprocessor.destruct(samples, data_samples)

         out_data_sample = DataSample(fake_img=samples, prompt=prompt)
@@ -226,14 +220,8 @@ def test_step(self, data: dict) -> SampleList:
         # construct a fake data_sample for destruct
         data_samples = DataSample.stack(data['data_samples'] * len(prompt))

-        unet_dtype = next(self.unet.parameters()).dtype
-        self.unet.to(self.dtype)
-
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']

-        self.unet.to(unet_dtype)
-
         samples = self.data_preprocessor.destruct(samples, data_samples)

         out_data_sample = DataSample(fake_img=samples, prompt=prompt)
21 changes: 21 additions & 0 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -6,6 +6,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from mmengine import print_log
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel
 from mmengine.runner import set_random_seed
@@ -90,6 +91,8 @@ def __init__(self,

         self.vae = build_module(vae, MODELS, default_args=default_args)
         self.unet = build_module(unet, MODELS)  # NOTE: initialize unet as fp32
+        self._unet_ori_dtype = next(self.unet.parameters()).dtype
+        print_log(f'Set UNet dtype to \'{self._unet_ori_dtype}\'.', 'current')
         self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
         if test_scheduler is None:
             self.test_scheduler = deepcopy(self.scheduler)
@@ -140,6 +143,24 @@ def set_tomesd(self) -> nn.Module:
     def device(self):
         return next(self.parameters()).device

+    def train(self, mode: bool = True):
"""Set train/eval mode.

Args:
mode (bool, optional): Whether set train mode. Defaults to True.
"""
if mode:
if next(self.unet.parameters()).dtype != self._unet_ori_dtype:
print_log(
f'Set UNet dtype to \'{self._unet_ori_dtype}\' '
'in the train mode.', 'current')
self.unet.to(self._unet_ori_dtype)
else:
self.unet.to(self.dtype)
print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.',
'current')
return super().train(mode)

@torch.no_grad()
def infer(self,
prompt: Union[str, List[str]],
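
Taken together with the ControlNet counterpart above, `StableDiffusion` now keeps the UNet in the fp32 weights it was built with whenever the model is in train mode, and casts it to the configured low-precision dtype only in eval mode. This is also why the manual `unet_dtype` save/restore in dreambooth.py could be deleted: assuming `DreamBooth` builds on `StableDiffusion`, the same cast now happens automatically whenever the runner flips between the train and val/test loops. A self-contained sketch of the resulting behaviour (stand-in class, not the real wrapper):

```python
import torch
import torch.nn as nn


class TinySD(nn.Module):
    """Stand-in for the StableDiffusion wrapper; attribute names mirror the PR."""

    def __init__(self, dtype=torch.float16):
        super().__init__()
        self.dtype = dtype                # low-precision eval/inference dtype
        self.unet = nn.Linear(4, 4)       # built as fp32
        self._unet_ori_dtype = next(self.unet.parameters()).dtype

    def train(self, mode: bool = True):
        # train mode: restore the dtype captured at __init__ (fp32);
        # eval mode: cast to the configured low-precision dtype.
        self.unet.to(self._unet_ori_dtype if mode else self.dtype)
        return super().train(mode)


model = TinySD()
model.eval()
assert next(model.unet.parameters()).dtype == torch.float16  # fp16 for inference
model.train()
assert next(model.unet.parameters()).dtype == torch.float32  # fp32 restored for training
```
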