Skip to content

Commit

Permalink
[Enhancement] Support noise offset in stable diffusion training (#1880)
Browse files Browse the repository at this point in the history
* support noise offset in stable diffusion training

* update arg list for dreambooth

* update arg list for control net
  • Loading branch information
LeoXing1996 authored May 29, 2023
1 parent 5f62dee commit 3c5bd83
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 4 deletions.
9 changes: 7 additions & 2 deletions mmagic/models/editors/controlnet/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ class ControlStableDiffusion(StableDiffusion):
dtype (str, optional): The dtype for the model. Defaults to 'fp32'.
enable_xformers (bool, optional): Whether to use xformers.
Defaults to True.
noise_offset_weight (float, optional): The weight of noise offset
introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise # noqa
Defaults to 0.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`. Defaults to
dict(type='DataPreprocessor').
Expand All @@ -63,12 +66,14 @@ def __init__(self,
test_scheduler: Optional[ModelType] = None,
dtype: str = 'fp32',
enable_xformers: bool = True,
noise_offset_weight: float = 0,
tomesd_cfg: Optional[dict] = None,
data_preprocessor=dict(type='DataPreprocessor'),
init_cfg: Optional[dict] = None):
super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
test_scheduler, dtype, enable_xformers, tomesd_cfg,
data_preprocessor, init_cfg)
test_scheduler, dtype, enable_xformers,
noise_offset_weight, tomesd_cfg, data_preprocessor,
init_cfg)

default_args = dict()
if dtype is not None:
Expand Down
9 changes: 7 additions & 2 deletions mmagic/models/editors/dreambooth/dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ class DreamBooth(StableDiffusion):
dtype (str, optional): The dtype for the model. Defaults to 'fp16'.
enable_xformers (bool, optional): Whether to use xformers.
Defaults to True.
noise_offset_weight (float, optional): The weight of noise offset
introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise # noqa
Defaults to 0.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`. Defaults to
dict(type='DataPreprocessor').
Expand All @@ -73,14 +76,16 @@ def __init__(self,
finetune_text_encoder: bool = False,
dtype: str = 'fp16',
enable_xformers: bool = True,
noise_offset_weight: float = 0,
tomesd_cfg: Optional[dict] = None,
data_preprocessor: Optional[ModelType] = dict(
type='DataPreprocessor'),
init_cfg: Optional[dict] = None):

super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
test_scheduler, dtype, enable_xformers, tomesd_cfg,
data_preprocessor, init_cfg)
test_scheduler, dtype, enable_xformers,
noise_offset_weight, tomesd_cfg, data_preprocessor,
init_cfg)
self.num_class_images = num_class_images
self.class_prior_prompt = class_prior_prompt
self.prior_loss_weight = prior_loss_weight
Expand Down
16 changes: 16 additions & 0 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class StableDiffusion(BaseModel):
when dtype is defined for submodels. Defaults to None.
enable_xformers (bool, optional): Whether to use xformers.
Defaults to True.
noise_offset_weight (float, optional): The weight of noise offset
introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise
Defaults to 0.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialized config for
Expand All @@ -62,6 +65,7 @@ def __init__(self,
test_scheduler: Optional[ModelType] = None,
dtype: Optional[str] = None,
enable_xformers: bool = True,
noise_offset_weight: float = 0,
tomesd_cfg: Optional[dict] = None,
data_preprocessor: Optional[ModelType] = dict(
type='DataPreprocessor'),
Expand Down Expand Up @@ -102,6 +106,9 @@ def __init__(self,
self.unet_sample_size = self.unet.sample_size
self.vae_scale_factor = 2**(len(self.vae.block_out_channels) - 1)

self.enable_noise_offset = noise_offset_weight > 0
self.noise_offset_weight = noise_offset_weight

self.enable_xformers = enable_xformers
self.set_xformers()

Expand Down Expand Up @@ -612,6 +619,15 @@ def train_step(self, data, optim_wrapper_dict):
latents = latents * vae.config.scaling_factor

noise = torch.randn_like(latents)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
latents.shape[0],
latents.shape[1],
1,
1,
device=noise.device)

timesteps = torch.randint(
0,
self.scheduler.num_train_timesteps, (num_batches, ),
Expand Down

0 comments on commit 3c5bd83

Please sign in to comment.