From e0beb6a9991c09ad8df87a37efe0b6b89a2b225e Mon Sep 17 00:00:00 2001 From: vvern999 <143861044+vvern999@users.noreply.github.com> Date: Sat, 2 Sep 2023 07:33:27 +0300 Subject: [PATCH] add input perturbation noise from https://arxiv.org/abs/2301.11706 --- library/train_util.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index ae1221314..0bb64e689 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -2865,6 +2865,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: default=None, help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)", ) + parser.add_argument( + "--ip_noise_gamma", + type=float, + default=None, + help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) / ", + ) # parser.add_argument( # "--perlin_noise", # type=int, @@ -4306,9 +4312,12 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents): timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device) timesteps = timesteps.long() - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + if args.ip_noise_gamma: + noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps) + else: + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) return noise, noisy_latents, timesteps