From 4ede43b86a9bcdb6db2a7507a50f3733aa131e97 Mon Sep 17 00:00:00 2001
From: hlky
Date: Sat, 23 Nov 2024 12:34:53 +0000
Subject: [PATCH] Combine Flow Match Euler into Euler

---
 .../pipeline_flux_differential_img2img.py     |   4 +-
 ...stable_diffusion_3_differential_img2img.py |   2 +-
 ...pipeline_flux_controlnet_image_to_image.py |   2 +-
 .../pipeline_flux_controlnet_inpainting.py    |   4 +-
 .../pipelines/flux/pipeline_flux_img2img.py   |   2 +-
 .../pipelines/flux/pipeline_flux_inpaint.py   |   4 +-
 .../pipeline_stable_diffusion_3_img2img.py    |   2 +-
 .../pipeline_stable_diffusion_3_inpaint.py    |   4 +-
 .../schedulers/scheduling_euler_discrete.py   | 157 +++++++---
 .../scheduling_flow_match_euler_discrete.py   | 276 +-----------------
 10 files changed, 135 insertions(+), 322 deletions(-)

diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 68cb69115bde3..7b9e5df5a4112 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -582,7 +582,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, timestep, noise)
         else:
             noise = latents.to(device)
             latents = noise
@@ -976,7 +976,7 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    image_latent = self.scheduler.scale_noise(
+                    image_latent = self.scheduler.add_noise(
                         original_image_latents, torch.tensor([noise_timestep]), noise
                     )
diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
index 8cee5ecbc1411..ee3a4ab1aebdf 100644
--- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -640,7 +640,7 @@ def prepare_latents(
         shape = init_latents.shape
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
-        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+        init_latents = self.scheduler.add_noise(init_latents, timestep, noise)
         latents = init_latents.to(device=device, dtype=dtype)
 
         return latents
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 6ab34d8a9c082..246e2e9c5607d 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -579,7 +579,7 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        latents = self.scheduler.add_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index d81cffaca35b6..a5fc296820df1 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -605,7 +605,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, timestep, noise)
         else:
             noise = latents.to(device)
             latents = noise
@@ -1154,7 +1154,7 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
+                    init_latents_proper = self.scheduler.add_noise(
                         init_latents_proper, torch.tensor([noise_timestep]), noise
                     )
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
index d34d9b53aa6bc..dd31f305f5c3a 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
@@ -562,7 +562,7 @@ def prepare_latents(
         image_latents = torch.cat([image_latents], dim=0)
 
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+        latents = self.scheduler.add_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
         return latents, latent_image_ids
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
index 3fcf6ace8a791..f31cc17896d0f 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
@@ -582,7 +582,7 @@ def prepare_latents(
 
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-            latents = self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = self.scheduler.add_noise(image_latents, timestep, noise)
         else:
             noise = latents.to(device)
             latents = noise
@@ -978,7 +978,7 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
+                    init_latents_proper = self.scheduler.add_noise(
                         init_latents_proper, torch.tensor([noise_timestep]), noise
                     )
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index a07a056ec851f..de0408e0a65f2 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -671,7 +671,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
         noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
 
         # get latents
-        init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
+        init_latents = self.scheduler.add_noise(init_latents, timestep, noise)
         latents = init_latents.to(device=device, dtype=dtype)
 
         return latents
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index d3e0ecf9c3a74..0074029f1bbf9 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -680,7 +680,7 @@ def prepare_latents(
         if latents is None:
             noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
-            latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
+            latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, timestep, noise)
         else:
             noise = latents.to(device)
             latents = noise
@@ -1145,7 +1145,7 @@ def __call__(
 
                 if i < len(timesteps) - 1:
                     noise_timestep = timesteps[i + 1]
-                    init_latents_proper = self.scheduler.scale_noise(
+                    init_latents_proper = self.scheduler.add_noise(
                         init_latents_proper, torch.tensor([noise_timestep]), noise
                     )
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 56757f3ca1977..6bd7e9b2a315b 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -196,6 +196,7 @@ def __init__(
         use_karras_sigmas: Optional[bool] = False,
         use_exponential_sigmas: Optional[bool] = False,
         use_beta_sigmas: Optional[bool] = False,
+        use_flow_match: Optional[bool] = False,
         sigma_min: Optional[float] = None,
         sigma_max: Optional[float] = None,
         timestep_spacing: str = "linspace",
@@ -203,6 +204,13 @@ def __init__(
         steps_offset: int = 0,
         rescale_betas_zero_snr: bool = False,
         final_sigmas_type: str = "zero",  # can be "zero" or "sigma_min"
+        shift: float = 1.0,
+        use_dynamic_shifting=False,
+        base_shift: Optional[float] = 0.5,
+        max_shift: Optional[float] = 1.15,
+        base_image_seq_len: Optional[int] = 256,
+        max_image_seq_len: Optional[int] = 4096,
+        invert_sigmas: bool = False,
     ):
         if self.config.use_beta_sigmas and not is_scipy_available():
             raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -234,20 +242,39 @@ def __init__(
             # FP16 smallest positive subnormal works well here
             self.alphas_cumprod[-1] = 2**-24
 
-        sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
-        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
+        if use_flow_match:
+            timestep_offset = 1
+        else:
+            timestep_offset = 0
+
+        timesteps = np.linspace(
+            0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
+        )[::-1].copy()
         timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
 
+        if use_flow_match:
+            sigmas = timesteps / num_train_timesteps
+            if not use_dynamic_shifting:
+                # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
+                sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
+        else:
+            sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
+
         # setable values
         self.num_inference_steps = None
 
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if timestep_type == "continuous" and prediction_type == "v_prediction":
            self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
+        elif use_flow_match:
+            self.timesteps = sigmas * num_train_timesteps
         else:
             self.timesteps = timesteps
 
-        self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+        if not use_flow_match:
+            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
+
+        self.sigmas = sigmas
         self.is_scale_input_called = False
 
         self.use_karras_sigmas = use_karras_sigmas
@@ -257,6 +284,8 @@ def __init__(
         self._step_index = None
         self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
+        self.sigma_min = self.sigmas[-1].item()
+        self.sigma_max = self.sigmas[0].item()
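
Reviewer note: a minimal NumPy sketch (illustrative, not part of the patch) of the sigma schedule the new `use_flow_match` branch builds in `__init__`; `shift=3.0` is an SD3-style example value, the combined scheduler defaults to `shift=1.0`:

```python
import numpy as np

num_train_timesteps = 1000
shift = 3.0  # example value only; the patch's default is shift=1.0

# With use_flow_match=True the timesteps are offset by 1 (`timestep_offset` above),
# so training sigmas run from 1.0 down to 1 / num_train_timesteps rather than hitting 0.
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1]
sigmas = timesteps / num_train_timesteps

# Static, resolution-independent shift, applied when use_dynamic_shifting=False.
shifted = shift * sigmas / (1 + (shift - 1) * sigmas)

print(sigmas[0], sigmas[-1])    # 1.0 0.001
print(shifted[0], shifted[-1])  # 1.0 ~0.003
```

The shift fixes sigma = 1 in place and inflates small sigmas by roughly a factor of `shift`, spending more of the schedule at high noise; `sigma_min`/`sigma_max` are then cached from the ends of `self.sigmas`.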
     @property
     def init_noise_sigma(self):
@@ -322,6 +351,7 @@ def set_timesteps(
         device: Union[str, torch.device] = None,
         timesteps: Optional[List[int]] = None,
         sigmas: Optional[List[float]] = None,
+        mu: Optional[float] = None,
     ):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -362,57 +392,81 @@ def set_timesteps(
             raise ValueError(
                 "Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
             )
+        if timesteps is not None and self.config.use_flow_match:
+            # TODO: `timesteps / self.config.num_train_timesteps` to get sigmas?
+            raise ValueError("Cannot set `timesteps` with `config.use_flow_match = True`.")
+
+        if self.config.use_dynamic_shifting and mu is None:
+            raise ValueError("You have to pass a value for `mu` when `use_dynamic_shifting` is set to `True`.")
 
         if num_inference_steps is None:
             num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
         self.num_inference_steps = num_inference_steps
 
-        if sigmas is not None:
+        if sigmas is not None and not self.config.use_flow_match:
             log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
             sigmas = np.array(sigmas).astype(np.float32)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
-
-        else:
+        elif sigmas is None:
             if timesteps is not None:
                 timesteps = np.array(timesteps).astype(np.float32)
             else:
-                # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
-                if self.config.timestep_spacing == "linspace":
+                if self.config.use_flow_match:
                     timesteps = np.linspace(
-                        0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
-                    )[::-1].copy()
-                elif self.config.timestep_spacing == "leading":
-                    step_ratio = self.config.num_train_timesteps // self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
-                    )
-                    timesteps += self.config.steps_offset
-                elif self.config.timestep_spacing == "trailing":
-                    step_ratio = self.config.num_train_timesteps / self.num_inference_steps
-                    # creates integer timesteps by multiplying by ratio
-                    # casting to int to avoid issues when num_inference_step is power of 3
-                    timesteps = (
-                        (np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
+                        self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
                     )
-                    timesteps -= 1
+                    sigmas = timesteps / self.config.num_train_timesteps
                 else:
-                    raise ValueError(
-                        f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
-                    )
+                    # "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
+                    if self.config.timestep_spacing == "linspace":
+                        timesteps = np.linspace(
+                            0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
+                        )[::-1].copy()
+                    elif self.config.timestep_spacing == "leading":
+                        step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
+                        )
+                        timesteps += self.config.steps_offset
+                    elif self.config.timestep_spacing == "trailing":
+                        step_ratio = self.config.num_train_timesteps / self.num_inference_steps
+                        # creates integer timesteps by multiplying by ratio
+                        # casting to int to avoid issues when num_inference_step is power of 3
+                        timesteps = (
+                            (np.arange(self.config.num_train_timesteps, 0, -step_ratio))
+                            .round()
+                            .copy()
+                            .astype(np.float32)
+                        )
+                        timesteps -= 1
+                    else:
+                        raise ValueError(
+                            f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
+                        )
+
+                    sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+                    if self.config.interpolation_type == "linear":
+                        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
+                    elif self.config.interpolation_type == "log_linear":
+                        sigmas = (
+                            torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)
+                            .exp()
+                            .numpy()
+                        )
+                    else:
+                        raise ValueError(
+                            f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
+                            " 'linear' or 'log_linear'"
+                        )
+
+        if self.config.use_flow_match:
+            if self.config.use_dynamic_shifting:
+                sigmas = self.time_shift(mu, 1.0, sigmas)
+            else:
+                sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
 
-        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
         log_sigmas = np.log(sigmas)
-        if self.config.interpolation_type == "linear":
-            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
-        elif self.config.interpolation_type == "log_linear":
-            sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
-        else:
-            raise ValueError(
-                f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
-                " 'linear' or 'log_linear'"
-            )
 
         if self.config.use_karras_sigmas:
             sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
@@ -426,10 +480,16 @@ def set_timesteps(
             sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
             timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])
 
-        if self.config.final_sigmas_type == "sigma_min":
+        if self.config.invert_sigmas:
+            sigmas = 1.0 - sigmas
+            timesteps = sigmas * self.config.num_train_timesteps
+
+        # `invert_sigmas` must be checked first: with the default `final_sigmas_type="zero"`,
+        # an `elif` placed after the "zero" branch would be unreachable
+        if self.config.invert_sigmas:
+            sigma_last = 1
+        elif self.config.final_sigmas_type == "sigma_min":
             sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
         elif self.config.final_sigmas_type == "zero":
             sigma_last = 0
         else:
             raise ValueError(
                 f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
             )
@@ -442,6 +502,8 @@ def set_timesteps(
         # TODO: Support the full EDM scalings for all prediction types and timestep types
         if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
             self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
+        elif self.config.use_flow_match:
+            self.timesteps = sigmas * self.config.num_train_timesteps
         else:
             self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)
 
         self._begin_index = None
         self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
-    def _sigma_to_t(self, sigma, log_sigmas):
+    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
+        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+
+    def _sigma_to_t(self, sigma, log_sigmas=None):
+        if self.config.use_flow_match:
+            return sigma * self.config.num_train_timesteps
         # get log sigma
         log_sigma = np.log(np.maximum(sigma, 1e-10))
@@ -622,7 +689,7 @@ def step(
             ),
         )
 
-        if not self.is_scale_input_called:
+        if not self.is_scale_input_called and not self.config.use_flow_match:
             logger.warning(
                 "The `scale_model_input` function should be called before `step` to ensure correct denoising. "
                 "See `StableDiffusionPipeline` for a usage example."
             )
@@ -663,7 +730,10 @@ def step(
         )
 
         # 2. Convert to an ODE derivative
-        derivative = (sample - pred_original_sample) / sigma_hat
+        if self.config.use_flow_match:
+            derivative = model_output
+        else:
+            derivative = (sample - pred_original_sample) / sigma_hat
 
         dt = self.sigmas[self.step_index + 1] - sigma_hat
 
@@ -713,7 +783,10 @@ def add_noise(
         while len(sigma.shape) < len(original_samples.shape):
             sigma = sigma.unsqueeze(-1)
 
-        noisy_samples = original_samples + noise * sigma
+        if self.config.use_flow_match:
+            noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
+        else:
+            noisy_samples = original_samples + noise * sigma
         return noisy_samples
 
     def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
diff --git a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
index c1096dbe0c29f..e39022d5d28a7 100644
--- a/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py
@@ -12,36 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
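
Reviewer note, before the file-level rewrite below: a self-contained sketch (plain PyTorch, ours rather than the patch's) of why `derivative = model_output` makes the shared Euler update reduce to the standalone flow-match step:

```python
import torch

# Flow-match forward process (the new `add_noise` branch):
# x_t = (1 - sigma) * x_0 + sigma * eps.
x0 = torch.randn(1, 4, 8, 8)
eps = torch.randn_like(x0)
sigma, sigma_next = 0.75, 0.50
xt = (1.0 - sigma) * x0 + sigma * eps

# In `step`, gamma churn is zero by default, so sigma_hat == sigma and
# dt = sigma_next - sigma. With derivative = model_output, the shared update
#   prev_sample = sample + dt * derivative
# is exactly the standalone scheduler's x_prev = x_t + (sigma_next - sigma) * v.
v = eps - x0  # the ideal flow-match velocity for this (x0, eps) pair
x_prev = xt + (sigma_next - sigma) * v

# Stepping with the ideal velocity lands exactly on the interpolant at sigma_next.
assert torch.allclose(x_prev, (1.0 - sigma_next) * x0 + sigma_next * eps, atol=1e-6)
```

This is also why the `scale_model_input` warning is disabled in flow-match mode: flow-match pipelines feed the raw latents to the model, so `is_scale_input_called` is never set.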
-import math
-from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
-
-import numpy as np
-import torch
+from typing import Optional
 
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import BaseOutput, logging
+from ..utils import logging
+from . import EulerDiscreteScheduler
 from .scheduling_utils import SchedulerMixin
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
-@dataclass
-class FlowMatchEulerDiscreteSchedulerOutput(BaseOutput):
-    """
-    Output class for the scheduler's `step` function output.
-
-    Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
-            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
-            denoising loop.
-    """
-
-    prev_sample: torch.FloatTensor
-
-
-class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
+class FlowMatchEulerDiscreteScheduler(EulerDiscreteScheduler, SchedulerMixin, ConfigMixin):
     """
     Euler scheduler.
 
@@ -58,9 +40,6 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         The shift value for the timestep schedule.
     """
 
-    _compatibles = []
-    order = 1
-
     @register_to_config
     def __init__(
         self,
@@ -72,247 +51,8 @@ def __init__(
         base_image_seq_len: Optional[int] = 256,
         max_image_seq_len: Optional[int] = 4096,
         invert_sigmas: bool = False,
+        use_karras_sigmas: Optional[bool] = False,
+        use_exponential_sigmas: Optional[bool] = False,
+        use_beta_sigmas: Optional[bool] = False,
     ):
-        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
-        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
-
-        sigmas = timesteps / num_train_timesteps
-        if not use_dynamic_shifting:
-            # when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
-            sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
-
-        self.timesteps = sigmas * num_train_timesteps
-
-        self._step_index = None
-        self._begin_index = None
-
-        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
-        self.sigma_min = self.sigmas[-1].item()
-        self.sigma_max = self.sigmas[0].item()
-
-    @property
-    def step_index(self):
-        """
-        The index counter for current timestep. It will increase 1 after each scheduler step.
-        """
-        return self._step_index
-
-    @property
-    def begin_index(self):
-        """
-        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
-        """
-        return self._begin_index
-
-    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
-    def set_begin_index(self, begin_index: int = 0):
-        """
-        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
-
-        Args:
-            begin_index (`int`):
-                The begin index for the scheduler.
-        """
-        self._begin_index = begin_index
-
-    def scale_noise(
-        self,
-        sample: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        noise: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
-        """
-        Forward process in flow-matching
-
-        Args:
-            sample (`torch.FloatTensor`):
-                The input sample.
-            timestep (`int`, *optional*):
-                The current timestep in the diffusion chain.
-
-        Returns:
-            `torch.FloatTensor`:
-                A scaled input sample.
-        """
-        # Make sure sigmas and timesteps have the same device and dtype as original_samples
-        sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
-
-        if sample.device.type == "mps" and torch.is_floating_point(timestep):
-            # mps does not support float64
-            schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
-            timestep = timestep.to(sample.device, dtype=torch.float32)
-        else:
-            schedule_timesteps = self.timesteps.to(sample.device)
-            timestep = timestep.to(sample.device)
-
-        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
-        if self.begin_index is None:
-            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
-        elif self.step_index is not None:
-            # add_noise is called after first denoising step (for inpainting)
-            step_indices = [self.step_index] * timestep.shape[0]
-        else:
-            # add noise is called before first denoising step to create initial latent(img2img)
-            step_indices = [self.begin_index] * timestep.shape[0]
-
-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < len(sample.shape):
-            sigma = sigma.unsqueeze(-1)
-
-        sample = sigma * noise + (1.0 - sigma) * sample
-
-        return sample
-
-    def _sigma_to_t(self, sigma):
-        return sigma * self.config.num_train_timesteps
-
-    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
-        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
-
-    def set_timesteps(
-        self,
-        num_inference_steps: int = None,
-        device: Union[str, torch.device] = None,
-        sigmas: Optional[List[float]] = None,
-        mu: Optional[float] = None,
-    ):
-        """
-        Sets the discrete timesteps used for the diffusion chain (to be run before inference).
-
-        Args:
-            num_inference_steps (`int`):
-                The number of diffusion steps used when generating samples with a pre-trained model.
-            device (`str` or `torch.device`, *optional*):
-                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
-        """
-
-        if self.config.use_dynamic_shifting and mu is None:
-            raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")
-
-        if sigmas is None:
-            self.num_inference_steps = num_inference_steps
-            timesteps = np.linspace(
-                self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
-            )
-
-            sigmas = timesteps / self.config.num_train_timesteps
-
-        if self.config.use_dynamic_shifting:
-            sigmas = self.time_shift(mu, 1.0, sigmas)
-        else:
-            sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
-
-        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
-        timesteps = sigmas * self.config.num_train_timesteps
-
-        if self.config.invert_sigmas:
-            sigmas = 1.0 - sigmas
-            timesteps = sigmas * self.config.num_train_timesteps
-            sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
-        else:
-            sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
-
-        self.timesteps = timesteps.to(device=device)
-        self.sigmas = sigmas
-        self._step_index = None
-        self._begin_index = None
-
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-
-        # The sigma index that is taken for the **very** first `step`
-        # is always the second index (or the last index if there is only 1)
-        # This way we can ensure we don't accidentally skip a sigma in
-        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        pos = 1 if len(indices) > 1 else 0
-
-        return indices[pos].item()
-
-    def _init_step_index(self, timestep):
-        if self.begin_index is None:
-            if isinstance(timestep, torch.Tensor):
-                timestep = timestep.to(self.timesteps.device)
-            self._step_index = self.index_for_timestep(timestep)
-        else:
-            self._step_index = self._begin_index
-
-    def step(
-        self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
-        s_churn: float = 0.0,
-        s_tmin: float = 0.0,
-        s_tmax: float = float("inf"),
-        s_noise: float = 1.0,
-        generator: Optional[torch.Generator] = None,
-        return_dict: bool = True,
-    ) -> Union[FlowMatchEulerDiscreteSchedulerOutput, Tuple]:
-        """
-        Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
-        process from the learned model outputs (most often the predicted noise).
-
-        Args:
-            model_output (`torch.FloatTensor`):
-                The direct output from learned diffusion model.
-            timestep (`float`):
-                The current discrete timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
-                A current instance of a sample created by the diffusion process.
-            s_churn (`float`):
-            s_tmin (`float`):
-            s_tmax (`float`):
-            s_noise (`float`, defaults to 1.0):
-                Scaling factor for noise added to the sample.
-            generator (`torch.Generator`, *optional*):
-                A random number generator.
-            return_dict (`bool`):
-                Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or
-                tuple.
-
-        Returns:
-            [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`:
-                If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is
-                returned, otherwise a tuple is returned where the first element is the sample tensor.
-        """
-
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
-            raise ValueError(
-                (
-                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
-                    " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
-                    " one of the `scheduler.timesteps` as a timestep."
-                ),
-            )
-
-        if self.step_index is None:
-            self._init_step_index(timestep)
-
-        # Upcast to avoid precision issues when computing prev_sample
-        sample = sample.to(torch.float32)
-
-        sigma = self.sigmas[self.step_index]
-        sigma_next = self.sigmas[self.step_index + 1]
-
-        prev_sample = sample + (sigma_next - sigma) * model_output
-
-        # Cast sample back to model compatible dtype
-        prev_sample = prev_sample.to(model_output.dtype)
-
-        # upon completion increase step index by one
-        self._step_index += 1
-
-        if not return_dict:
-            return (prev_sample,)
-
-        return FlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample)
-
-    def __len__(self):
-        return self.config.num_train_timesteps
+        # forward the config registered by `@register_to_config` to the combined
+        # Euler scheduler, switching it into flow-match mode
+        super().__init__(**self.config, use_flow_match=True)