Skip to content

Commit

Permalink
EDMEulerScheduler accept sigmas, add final_sigmas_type (huggingface#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky authored Feb 7, 2025
1 parent d43ce14 commit 464374f
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions src/diffusers/schedulers/scheduling_edm_euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import List, Optional, Tuple, Union

import torch

Expand Down Expand Up @@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
Video](https://imagen.research.google/video/paper.pdf) paper).
rho (`float`, *optional*, defaults to 7.0):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
"""

_compatibles = []
Expand All @@ -92,22 +95,32 @@ def __init__(
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
rho: float = 7.0,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")

# setable values
self.num_inference_steps = None

ramp = torch.linspace(0, 1, num_train_timesteps)
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
if sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
sigmas = self._compute_karras_sigmas(sigmas)
elif sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = self._compute_exponential_sigmas(sigmas)

self.timesteps = self.precondition_noise(sigmas)

self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)

self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])

self.is_scale_input_called = False

Expand Down Expand Up @@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample

def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Expand All @@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
Custom sigmas to use for the denoising process. If not defined, the default behavior when
`num_inference_steps` is passed will be used.
"""
self.num_inference_steps = num_inference_steps

ramp = torch.linspace(0, 1, self.num_inference_steps)
if sigmas is None:
sigmas = torch.linspace(0, 1, self.num_inference_steps)
elif isinstance(sigmas, float):
sigmas = torch.tensor(sigmas, dtype=torch.float32)
else:
sigmas = sigmas
if self.config.sigma_schedule == "karras":
sigmas = self._compute_karras_sigmas(ramp)
sigmas = self._compute_karras_sigmas(sigmas)
elif self.config.sigma_schedule == "exponential":
sigmas = self._compute_exponential_sigmas(ramp)
sigmas = self._compute_exponential_sigmas(sigmas)

sigmas = sigmas.to(dtype=torch.float32, device=device)
self.timesteps = self.precondition_noise(sigmas)

self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.final_sigmas_type == "sigma_min":
sigma_last = sigmas[-1]
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
)

self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
Expand Down

0 comments on commit 464374f

Please sign in to comment.