Skip to content

Commit

Permalink
support noise offset in stable diffusion training
Browse files Browse the repository at this point in the history
  • Loading branch information
LeoXing1996 committed May 29, 2023
1 parent 47a54de commit 51f1096
Showing 1 changed file with 16 additions and 0 deletions.
16 changes: 16 additions & 0 deletions mmagic/models/editors/stable_diffusion/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class StableDiffusion(BaseModel):
when dtype is defined for submodels. Defaults to None.
enable_xformers (bool, optional): Whether to use xformers.
Defaults to True.
noise_offset_weight (bool, optional): The weight of noise offset
introduced in https://www.crosslabs.org/blog/diffusion-with-offset-noise
Defaults to 0.
data_preprocessor (dict, optional): The pre-process config of
:class:`BaseDataPreprocessor`.
init_cfg (dict, optional): The weight initialized config for
Expand All @@ -62,6 +65,7 @@ def __init__(self,
test_scheduler: Optional[ModelType] = None,
dtype: Optional[str] = None,
enable_xformers: bool = True,
noise_offset_weight: float = 0,
tomesd_cfg: Optional[dict] = None,
data_preprocessor: Optional[ModelType] = dict(
type='DataPreprocessor'),
Expand Down Expand Up @@ -92,6 +96,9 @@ def __init__(self,
self.unet_sample_size = self.unet.sample_size
self.vae_scale_factor = 2**(len(self.vae.block_out_channels) - 1)

self.enable_noise_offset = noise_offset_weight > 0
self.noise_offset_weight = noise_offset_weight

self.enable_xformers = enable_xformers
self.set_xformers()

Expand Down Expand Up @@ -602,6 +609,15 @@ def train_step(self, data, optim_wrapper_dict):
latents = latents * vae.config.scaling_factor

noise = torch.randn_like(latents)

if self.enable_noise_offset:
noise = noise + self.noise_offset_weight * torch.randn(
latents.shape[0],
latents.shape[1],
1,
1,
device=noise.device)

timesteps = torch.randint(
0,
self.scheduler.num_train_timesteps, (num_batches, ),
Expand Down

0 comments on commit 51f1096

Please sign in to comment.