From ec3f2cc5e36704bb232257beab9d80fc8a0abf42 Mon Sep 17 00:00:00 2001
From: hlky
Date: Tue, 24 Sep 2024 01:42:23 +0100
Subject: [PATCH] Add beta sigmas / beta noise schedule

---
 .../schedulers/scheduling_euler_discrete.py | 55 +++++++++++++++++--
 1 file changed, 50 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 2b74558d3cb75..142dff3f56cfb 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -17,10 +17,11 @@
 from typing import List, Optional, Tuple, Union
 
 import numpy as np
+import scipy.stats
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import BaseOutput, is_scipy_available, logging
 from ..utils.torch_utils import randn_tensor
 from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 
@@ -160,6 +161,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             the sigmas are determined according to a sequence of noise levels {σi}.
         use_exponential_sigmas (`bool`, *optional*, defaults to `False`):
             Whether to use exponential sigmas for step sizes in the noise schedule during the sampling process.
+        use_beta_sigmas (`bool`, *optional*, defaults to `False`):
+            Whether to use beta sigmas for step sizes in the noise schedule during the sampling process. Refer to
+            [Beta Sampling is All You Need](https://huggingface.co/papers/2407.12173) for more information.
         timestep_spacing (`str`, defaults to `"linspace"`):
             The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
             Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
@@ -189,6 +193,7 @@ def __init__(
         interpolation_type: str = "linear",
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -197,6 +202,12 @@ def __init__(
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
     ):
+        if self.config.use_beta_sigmas and not is_scipy_available():
+            raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
+        if sum([self.config.use_beta_sigmas, self.config.use_exponential_sigmas, self.config.use_karras_sigmas]) > 1:
+            raise ValueError(
+                "Only one of `config.use_beta_sigmas`, `config.use_exponential_sigmas`, `config.use_karras_sigmas` can be used."
+            )
         if trained_betas is not None:
             self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
@@ -239,6 +250,7 @@ def __init__(
         self.is_scale_input_called = False
         self.use_karras_sigmas = use_karras_sigmas
         self.use_exponential_sigmas = use_exponential_sigmas
+        self.use_beta_sigmas = use_beta_sigmas
 
         self._step_index = None
         self._begin_index = None
@@ -338,10 +350,8 @@ def set_timesteps(
             raise ValueError("Cannot set `timesteps` with `config.use_karras_sigmas = True`.")
         if timesteps is not None and self.config.use_exponential_sigmas:
             raise ValueError("Cannot set `timesteps` with `config.use_exponential_sigmas = True`.")
-        if self.config.use_exponential_sigmas and self.config.use_karras_sigmas:
-            raise ValueError(
-                "Cannot set both `config.use_exponential_sigmas = True` and config.use_karras_sigmas = True`"
-            )
+        if timesteps is not None and self.config.use_beta_sigmas:
+            raise ValueError("Cannot set `timesteps` with `config.use_beta_sigmas = True`.")
         if (
             timesteps is not None
             and self.config.timestep_type == "continuous"
@@ -408,6 +418,10 @@ def set_timesteps(
         elif self.config.use_exponential_sigmas:
             sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
+            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
+
+        elif self.config.use_beta_sigmas:
+            sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
         if self.config.final_sigmas_type == "sigma_min":
@@ -504,6 +518,37 @@ def _convert_to_exponential(self, in_sigmas: torch.Tensor, num_inference_steps:
         sigmas = torch.linspace(math.log(sigma_max), math.log(sigma_min), num_inference_steps).exp()
         return sigmas
 
+    def _convert_to_beta(
+        self, in_sigmas: torch.Tensor, num_inference_steps: int, alpha: float = 0.6, beta: float = 0.6
+    ) -> torch.Tensor:
+        """From "Beta Sampling is All You Need" [arXiv:2407.12173] (Lee et al., 2024)"""
+
+        # Hack to make sure that other schedulers which copy this function don't break
+        # TODO: Add this logic to the other schedulers
+        if hasattr(self.config, "sigma_min"):
+            sigma_min = self.config.sigma_min
+        else:
+            sigma_min = None
+
+        if hasattr(self.config, "sigma_max"):
+            sigma_max = self.config.sigma_max
+        else:
+            sigma_max = None
+
+        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
+        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
+
+        sigmas = torch.FloatTensor(
+            [
+                sigma_min + (ppf * (sigma_max - sigma_min))
+                for ppf in [
+                    scipy.stats.beta.ppf(timestep, alpha, beta)
+                    for timestep in 1 - np.linspace(0, 1, num_inference_steps)
+                ]
+            ]
+        )
+        return sigmas
+
     def index_for_timestep(self, timestep, schedule_timesteps=None):
         if schedule_timesteps is None:
             schedule_timesteps = self.timesteps
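
Note: the short sketch below is not part of the patch. It is a minimal standalone illustration of the beta-ppf sigma spacing that `_convert_to_beta` implements; the helper name `beta_sigmas` and the example sigma range are illustrative only, and nothing beyond NumPy and `scipy.stats.beta.ppf` is assumed.

    import numpy as np
    import scipy.stats

    def beta_sigmas(sigma_min, sigma_max, num_inference_steps, alpha=0.6, beta=0.6):
        # Evaluate the beta distribution's inverse CDF (ppf) on a reversed uniform grid,
        # then rescale the resulting percentiles into [sigma_min, sigma_max].
        timesteps = 1 - np.linspace(0, 1, num_inference_steps)
        ppf = scipy.stats.beta.ppf(timesteps, alpha, beta)
        return sigma_min + ppf * (sigma_max - sigma_min)

    # Sigmas decrease from sigma_max to sigma_min and, with alpha = beta = 0.6,
    # cluster near both ends of the range (more steps at high and low noise levels).
    print(beta_sigmas(0.02, 14.6, 10))  # example values, chosen only for illustration

With the patch applied, the schedule is enabled by constructing the scheduler with `use_beta_sigmas=True` (e.g. `EulerDiscreteScheduler(use_beta_sigmas=True)`); only one of `use_beta_sigmas`, `use_exponential_sigmas`, and `use_karras_sigmas` may be set at a time.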