Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Z-Fran committed Jun 26, 2023
1 parent 74275c9 commit 720c176
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 40 deletions.
2 changes: 1 addition & 1 deletion mmagic/apis/inferencers/controlnet_animation_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self,
self.inference_method = cfg.inference_method
if self.inference_method == 'attention_injection':
cfg.model.attention_injection = True
self.pipe = MODELS.build(cfg.model).cuda()
self.pipe = MODELS.build(cfg.model).cuda().eval()

control_scheduler_cfg = dict(
type=cfg.control_scheduler,
Expand Down
50 changes: 33 additions & 17 deletions mmagic/models/editors/controlnet/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ def __init__(self,

# NOTE: initialize controlnet as fp32
self.controlnet = build_module(controlnet, MODELS)
self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype
print_log(
'Set ControlNetModel dtype to '
f'\'{self._controlnet_ori_dtype}\'.', 'current')
self.set_xformers(self.controlnet)

self.vae.requires_grad_(False)
Expand Down Expand Up @@ -272,18 +276,10 @@ def val_step(self, data: dict) -> SampleList:
prompt = data['data_samples'].prompt
control = data['inputs']['source']

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)
controlnet_dtype = next(self.controlnet.parameters()).dtype
self.controlnet.to(self.dtype)

output = self.infer(
prompt, control=((control + 1) / 2), return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)
self.controlnet.to(controlnet_dtype)

samples = self.data_preprocessor.destruct(
samples, data['data_samples'], key='target')
control = self.data_preprocessor.destruct(
Expand Down Expand Up @@ -311,18 +307,10 @@ def test_step(self, data: dict) -> SampleList:
prompt = data['data_samples'].prompt
control = data['inputs']['source']

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)
controlnet_dtype = next(self.controlnet.parameters()).dtype
self.controlnet.to(self.dtype)

output = self.infer(
prompt, control=((control + 1) / 2), return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)
self.controlnet.to(controlnet_dtype)

samples = self.data_preprocessor.destruct(
samples, data['data_samples'], key='target')
control = self.data_preprocessor.destruct(
Expand Down Expand Up @@ -387,6 +375,27 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,

return image

def train(self, mode: bool = True):
    """Toggle train/eval mode and cast the ControlNet to the matching dtype.

    Args:
        mode (bool, optional): ``True`` selects train mode, ``False``
            selects eval mode. Defaults to True.

    Returns:
        Whatever the parent class's ``train`` returns for ``mode``.
    """
    if not mode:
        # Eval: cast the ControlNet to ``self.dtype`` and log the switch.
        self.controlnet.to(self.dtype)
        print_log(
            f'Set ControlNetModel dtype to \'{self.dtype}\' '
            'in the eval mode.', 'current')
        return super().train(mode)

    # Train: go back to the dtype kept in ``self._controlnet_ori_dtype``.
    # The log line is emitted only when the first parameter's dtype
    # currently differs from that stored value.
    current_dtype = next(self.controlnet.parameters()).dtype
    if current_dtype != self._controlnet_ori_dtype:
        print_log(
            'Set ControlNetModel dtype to '
            f'\'{self._controlnet_ori_dtype}\' in the train mode.',
            'current')
    self.controlnet.to(self._controlnet_ori_dtype)
    return super().train(mode)

@torch.no_grad()
def infer(self,
prompt: Union[str, List[str]],
Expand Down Expand Up @@ -791,6 +800,8 @@ def infer(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

img_dtype = self.vae.module.dtype if hasattr(
self.vae, 'module') else self.vae.dtype
if is_model_wrapper(self.controlnet):
control_dtype = self.controlnet.module.dtype
else:
Expand Down Expand Up @@ -818,6 +829,7 @@ def infer(
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt)
text_embeddings = text_embeddings.to(control_dtype)

# 4. Prepare timesteps
# self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -866,6 +878,8 @@ def infer(
latent_model_input = self.test_scheduler.scale_model_input(
latent_model_input, t)

latent_model_input = latent_model_input.to(control_dtype)

if reference_img is not None:
ref_img_vae_latents_t = self.scheduler.add_noise(
ref_img_vae_latents, torch.randn_like(ref_img_vae_latents),
Expand All @@ -876,6 +890,8 @@ def infer(
ref_img_vae_latents_model_input = \
self.test_scheduler.scale_model_input(
ref_img_vae_latents_model_input, t)
ref_img_vae_latents_model_input = \
ref_img_vae_latents_model_input.to(control_dtype)

down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
Expand Down Expand Up @@ -924,7 +940,7 @@ def infer(
vae_encode_latents * (1.0 - latent_mask)

# 8. Post-processing
image = self.decode_latents(latents)
image = self.decode_latents(latents.to(img_dtype))

if do_classifier_free_guidance:
controls = torch.split(controls, controls.shape[0] // 2, dim=0)[0]
Expand Down
12 changes: 0 additions & 12 deletions mmagic/models/editors/dreambooth/dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,8 @@ def val_step(self, data: dict) -> SampleList:
data_samples.split() * len(prompt)
data_samples = DataSample.stack(data_samples.split() * len(prompt))

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)

output = self.infer(prompt, return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)

samples = self.data_preprocessor.destruct(samples, data_samples)

out_data_sample = DataSample(fake_img=samples, prompt=prompt)
Expand Down Expand Up @@ -226,14 +220,8 @@ def test_step(self, data: dict) -> SampleList:
# construct a fake data_sample for destruct
data_samples = DataSample.stack(data['data_samples'] * len(prompt))

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)

output = self.infer(prompt, return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)

samples = self.data_preprocessor.destruct(samples, data_samples)

out_data_sample = DataSample(fake_img=samples, prompt=prompt)
Expand Down
31 changes: 21 additions & 10 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine import print_log
from mmengine.logging import MMLogger
from mmengine.model import BaseModel
from mmengine.runner import set_random_seed
Expand Down Expand Up @@ -90,6 +91,8 @@ def __init__(self,

self.vae = build_module(vae, MODELS, default_args=default_args)
self.unet = build_module(unet, MODELS) # NOTE: initialize unet as fp32
self._unet_ori_dtype = next(self.unet.parameters()).dtype
print_log(f'Set UNet dtype to \'{self._unet_ori_dtype}\'.', 'current')
self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
if test_scheduler is None:
self.test_scheduler = deepcopy(self.scheduler)
Expand Down Expand Up @@ -140,6 +143,24 @@ def set_tomesd(self) -> nn.Module:
def device(self):
return next(self.parameters()).device

def train(self, mode: bool = True):
    """Toggle train/eval mode and cast the UNet to the matching dtype.

    Args:
        mode (bool, optional): ``True`` selects train mode, ``False``
            selects eval mode. Defaults to True.

    Returns:
        Whatever the parent class's ``train`` returns for ``mode``.
    """
    if not mode:
        # Eval: cast the UNet to ``self.dtype`` and log the switch.
        self.unet.to(self.dtype)
        print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.',
                  'current')
        return super().train(mode)

    # Train: go back to the dtype kept in ``self._unet_ori_dtype``. The
    # log line is emitted only when the first parameter's dtype currently
    # differs from that stored value.
    current_dtype = next(self.unet.parameters()).dtype
    if current_dtype != self._unet_ori_dtype:
        print_log(
            f'Set UNet dtype to \'{self._unet_ori_dtype}\' '
            'in the train mode.', 'current')
    self.unet.to(self._unet_ori_dtype)
    return super().train(mode)

@torch.no_grad()
def infer(self,
prompt: Union[str, List[str]],
Expand Down Expand Up @@ -572,14 +593,9 @@ def val_step(self, data: dict) -> SampleList:
data_samples = data['data_samples']
prompt = data_samples.prompt

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)

output = self.infer(prompt, return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)

samples = self.data_preprocessor.destruct(samples, data_samples)
gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples)

Expand All @@ -596,14 +612,9 @@ def test_step(self, data: dict) -> SampleList:
data_samples = data['data_samples']
prompt = data_samples.prompt

unet_dtype = next(self.unet.parameters()).dtype
self.unet.to(self.dtype)

output = self.infer(prompt, return_type='tensor')
samples = output['samples']

self.unet.to(unet_dtype)

samples = self.data_preprocessor.destruct(samples, data_samples)
gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples)

Expand Down

0 comments on commit 720c176

Please sign in to comment.