Skip to content

Commit

Permalink
Merge pull request #798 from vvern999/vvern999-patch-1
Browse files Browse the repository at this point in the history
add input perturbation noise
  • Loading branch information
kohya-ss authored Sep 3, 2023
2 parents f6d417e + e0beb6a commit 2eae9b6
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,12 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="enable multires noise with this number of iterations (if enabled, around 6-10 is recommended) / Multires noiseを有効にしてこのイテレーション数を設定する(有効にする場合は6-10程度を推奨)",
)
parser.add_argument(
"--ip_noise_gamma",
type=float,
default=None,
help="enable input perturbation noise. used for regularization. recommended value: around 0.1 (from arxiv.org/abs/2301.11706) / ",
)
# parser.add_argument(
# "--perlin_noise",
# type=int,
Expand Down Expand Up @@ -4347,9 +4353,12 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=latents.device)
timesteps = timesteps.long()

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
if args.ip_noise_gamma:
noisy_latents = noise_scheduler.add_noise(latents, noise + args.ip_noise_gamma * torch.randn_like(latents), timesteps)
else:
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

return noise, noisy_latents, timesteps

Expand Down

0 comments on commit 2eae9b6

Please sign in to comment.