Skip to content

Commit

Permalink
fix: zero terminal SNR — rescale betas before caching SNR
Browse files Browse the repository at this point in the history
  • Loading branch information
feffy380 committed Dec 22, 2024
1 parent 8805d2d commit 8e041f4
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 8 deletions.
8 changes: 3 additions & 5 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ def prepare_scheduler_for_custom_training(noise_scheduler, device):
return

alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
alpha = sqrt_alphas_cumprod
sigma = sqrt_one_minus_alphas_cumprod
all_snr = (alpha / sigma) ** 2
all_snr = alphas_cumprod / (1.0 - alphas_cumprod)
# avoid division by zero
all_snr[-1] = max(all_snr[-1], 4.8973451890853435e-08)

noise_scheduler.all_snr = all_snr.to(device)

Expand Down
4 changes: 1 addition & 3 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,9 @@ def post_process_network(self, args, accelerator, network, text_encoders, unet):

def get_noise_scheduler(self, args: argparse.Namespace, device: torch.device) -> Any:
noise_scheduler = DDPMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, clip_sample=False, rescale_betas_zero_snr=args.zero_terminal_snr
)
prepare_scheduler_for_custom_training(noise_scheduler, device)
if args.zero_terminal_snr:
custom_train_functions.fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler)
return noise_scheduler

def encode_images_to_latents(self, args, accelerator, vae, images):
Expand Down

0 comments on commit 8e041f4

Please sign in to comment.