
Commit 74275c9
fix
Z-Fran committed Jun 26, 2023
1 parent 47f44ff commit 74275c9
Showing 2 changed files with 38 additions and 47 deletions.
56 changes: 28 additions & 28 deletions mmagic/models/editors/controlnet/controlnet.py
@@ -83,7 +83,6 @@ def __init__(self,
 
         # NOTE: initialize controlnet as fp32
         self.controlnet = build_module(controlnet, MODELS)
-        self._controlnet_ori_dtype = next(self.controlnet.parameters()).dtype
         self.set_xformers(self.controlnet)
 
         self.vae.requires_grad_(False)
@@ -203,6 +202,7 @@ def train_step(self, data: dict,
 
         num_batches = target.shape[0]
 
+        target = target.to(self.dtype)
         latents = self.vae.encode(target).latent_dist.sample()
         latents = latents * self.vae.config.scaling_factor
 
@@ -233,19 +233,20 @@ def train_step(self, data: dict,
                             f'{self.scheduler.config.prediction_type}')
 
         # forward control
+        # NOTE: we train controlnet in fp32, convert to float manually
         down_block_res_samples, mid_block_res_sample = self.controlnet(
-            noisy_latents,
+            noisy_latents.float(),
             timesteps,
-            encoder_hidden_states=encoder_hidden_states,
-            controlnet_cond=control,
+            encoder_hidden_states=encoder_hidden_states.float(),
+            controlnet_cond=control.float(),
             return_dict=False,
         )
 
         # Predict the noise residual and compute loss
+        # NOTE: we train unet in fp32, convert to float manually
         model_output = self.unet(
-            noisy_latents,
+            noisy_latents.float(),
             timesteps,
-            encoder_hidden_states=encoder_hidden_states,
+            encoder_hidden_states=encoder_hidden_states.float(),
             down_block_additional_residuals=down_block_res_samples,
             mid_block_additional_residual=mid_block_res_sample)
         model_pred = model_output['sample']
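Note on the .float() casts in the hunk above: with the ControlNet and UNet weights kept in fp32 while the rest of the pipeline may carry fp16 tensors, PyTorch will not promote half-precision inputs automatically, so mixing dtypes in a forward pass raises an error. A minimal standalone sketch of that behaviour (the module and tensor names are illustrative, not from mmagic):

import torch
import torch.nn as nn

fp32_module = nn.Linear(4, 4)                          # weights kept in float32
half_inputs = torch.randn(2, 4, dtype=torch.float16)   # e.g. fp16 latents

# Calling fp32_module(half_inputs) directly would raise a dtype-mismatch
# RuntimeError, so the inputs are cast explicitly, mirroring the .float()
# calls in the hunk above.
outputs = fp32_module(half_inputs.float())
assert outputs.dtype == torch.float32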
@@ -270,10 +271,19 @@ def val_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
 
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+        controlnet_dtype = next(self.controlnet.parameters()).dtype
+        self.controlnet.to(self.dtype)
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
 
         samples = output['samples']
 
+        self.unet.to(unet_dtype)
+        self.controlnet.to(controlnet_dtype)
+
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -300,10 +310,19 @@ def test_step(self, data: dict) -> SampleList:
         data = self.data_preprocessor(data)
         prompt = data['data_samples'].prompt
         control = data['inputs']['source']
 
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+        controlnet_dtype = next(self.controlnet.parameters()).dtype
+        self.controlnet.to(self.dtype)
+
         output = self.infer(
             prompt, control=((control + 1) / 2), return_type='tensor')
 
         samples = output['samples']
 
+        self.unet.to(unet_dtype)
+        self.controlnet.to(controlnet_dtype)
+
         samples = self.data_preprocessor.destruct(
             samples, data['data_samples'], key='target')
         control = self.data_preprocessor.destruct(
@@ -368,25 +387,6 @@ def prepare_control(image: Tuple[Image.Image, List[Image.Image], Tensor,
 
         return image
 
-    def train(self, mode: bool = True):
-        """Set train/eval mode.
-
-        Args:
-            mode (bool, optional): Whether set train mode. Defaults to True.
-        """
-        if mode:
-            self.controlnet.to(self._controlnet_ori_dtype)
-            print_log(
-                'Set ControlNetModel dtype to '
-                f'\'{self._controlnet_ori_dtype}\' in the train mode.',
-                'current')
-        else:
-            self.controlnet.to(self.dtype)
-            print_log(
-                f'Set ControlNetModel dtype to \'{self.dtype}\' '
-                'in the eval mode.', 'current')
-        return super().train(mode)
-
     @torch.no_grad()
     def infer(self,
               prompt: Union[str, List[str]],
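The save/cast/restore sequence around infer() in val_step and test_step above could also be factored into a small context manager. A sketch of that idea, not part of this commit (the helper name module_in_dtype is hypothetical):

from contextlib import contextmanager

import torch.nn as nn


@contextmanager
def module_in_dtype(module: nn.Module, dtype):
    """Temporarily cast ``module`` to ``dtype`` and restore its original dtype."""
    original = next(module.parameters()).dtype
    module.to(dtype)
    try:
        yield module
    finally:
        module.to(original)

With such a helper, the inference call would read roughly: with module_in_dtype(self.unet, self.dtype), module_in_dtype(self.controlnet, self.dtype): output = self.infer(...), and the restore happens even if infer() raises.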
29 changes: 10 additions & 19 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
@@ -6,7 +6,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from mmengine import print_log
 from mmengine.logging import MMLogger
 from mmengine.model import BaseModel
 from mmengine.runner import set_random_seed
@@ -91,7 +90,6 @@ def __init__(self,
 
         self.vae = build_module(vae, MODELS, default_args=default_args)
         self.unet = build_module(unet, MODELS)  # NOTE: initialize unet as fp32
-        self._unet_ori_dtype = next(self.unet.parameters()).dtype
         self.scheduler = build_module(scheduler, DIFFUSION_SCHEDULERS)
         if test_scheduler is None:
             self.test_scheduler = deepcopy(self.scheduler)
@@ -142,23 +140,6 @@ def set_tomesd(self) -> nn.Module:
     def device(self):
         return next(self.parameters()).device
 
-    def train(self, mode: bool = True):
-        """Set train/eval mode.
-
-        Args:
-            mode (bool, optional): Whether set train mode. Defaults to True.
-        """
-        if mode:
-            self.unet.to(self._unet_ori_dtype)
-            print_log(
-                f'Set UNet dtype to \'{self._unet_ori_dtype}\' '
-                'in the train mode.', 'current')
-        else:
-            self.unet.to(self.dtype)
-            print_log(f'Set UNet dtype to \'{self.dtype}\' in the eval mode.',
-                      'current')
-        return super().train(mode)
-
     @torch.no_grad()
     def infer(self,
               prompt: Union[str, List[str]],
@@ -591,9 +572,14 @@ def val_step(self, data: dict) -> SampleList:
         data_samples = data['data_samples']
         prompt = data_samples.prompt
 
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']
 
+        self.unet.to(unet_dtype)
+
         samples = self.data_preprocessor.destruct(samples, data_samples)
         gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples)
 
@@ -610,9 +596,14 @@ def test_step(self, data: dict) -> SampleList:
         data_samples = data['data_samples']
         prompt = data_samples.prompt
 
+        unet_dtype = next(self.unet.parameters()).dtype
+        self.unet.to(self.dtype)
+
         output = self.infer(prompt, return_type='tensor')
         samples = output['samples']
 
+        self.unet.to(unet_dtype)
+
         samples = self.data_preprocessor.destruct(samples, data_samples)
         gt_img = self.data_preprocessor.destruct(data['inputs'], data_samples)
 
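The val_step/test_step hunks above save the UNet's current dtype, cast it to self.dtype for sampling, and restore the saved dtype afterwards, so the weights come back in their original precision once inference is done. A toy round-trip check of that pattern (nn.Linear stands in for the UNet and torch.float16 for self.dtype; both are assumptions, not mmagic code):

import torch
import torch.nn as nn

unet_like = nn.Linear(8, 8)                      # weights held in float32
inference_dtype = torch.float16                  # stand-in for self.dtype

saved_dtype = next(unet_like.parameters()).dtype
unet_like.to(inference_dtype)                    # cast for sampling
# ... the equivalent of self.infer(...) would run here ...
unet_like.to(saved_dtype)                        # restore the original dtype

assert next(unet_like.parameters()).dtype == torch.float32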
