Skip to content

Commit

Permalink
Combine Flow Match Euler into Euler
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Nov 23, 2024
1 parent b5fd6f1 commit 4ede43b
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 322 deletions.
4 changes: 2 additions & 2 deletions examples/community/pipeline_flux_differential_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def prepare_latents(

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self.scheduler.add_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise
Expand Down Expand Up @@ -976,7 +976,7 @@ def __call__(

if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.scale_noise(
image_latent = self.scheduler.add_noise(
original_image_latents, torch.tensor([noise_timestep]), noise
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ def prepare_latents(
shape = init_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
init_latents = self.scheduler.add_noise(init_latents, timestep, noise)
latents = init_latents.to(device=device, dtype=dtype)

return latents
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def prepare_latents(
image_latents = torch.cat([image_latents], dim=0)

noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self.scheduler.add_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def prepare_latents(

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self.scheduler.add_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise
Expand Down Expand Up @@ -1154,7 +1154,7 @@ def __call__(

if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.scale_noise(
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, torch.tensor([noise_timestep]), noise
)

Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/flux/pipeline_flux_img2img.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,7 @@ def prepare_latents(
image_latents = torch.cat([image_latents], dim=0)

noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self.scheduler.add_noise(image_latents, timestep, noise)
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
return latents, latent_image_ids

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/flux/pipeline_flux_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ def prepare_latents(

if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
latents = self.scheduler.add_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise
Expand Down Expand Up @@ -978,7 +978,7 @@ def __call__(

if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.scale_noise(
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, torch.tensor([noise_timestep]), noise
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

# get latents
init_latents = self.scheduler.scale_noise(init_latents, timestep, noise)
init_latents = self.scheduler.add_noise(init_latents, timestep, noise)
latents = init_latents.to(device=device, dtype=dtype)

return latents
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ def prepare_latents(
if latents is None:
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.scale_noise(image_latents, timestep, noise)
latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, timestep, noise)
else:
noise = latents.to(device)
latents = noise
Expand Down Expand Up @@ -1145,7 +1145,7 @@ def __call__(

if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.scale_noise(
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, torch.tensor([noise_timestep]), noise
)

Expand Down
157 changes: 115 additions & 42 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,13 +196,21 @@ def __init__(
use_karras_sigmas: Optional[bool] = False,
use_exponential_sigmas: Optional[bool] = False,
use_beta_sigmas: Optional[bool] = False,
use_flow_match: Optional[bool] = False,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
timestep_spacing: str = "linspace",
timestep_type: str = "discrete", # can be "discrete" or "continuous"
steps_offset: int = 0,
rescale_betas_zero_snr: bool = False,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
shift: float = 1.0,
use_dynamic_shifting=False,
base_shift: Optional[float] = 0.5,
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
):
if self.config.use_beta_sigmas and not is_scipy_available():
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
Expand Down Expand Up @@ -234,20 +242,39 @@ def __init__(
# FP16 smallest positive subnormal works well here
self.alphas_cumprod[-1] = 2**-24

sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)
timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
if use_flow_match:
timestep_offset = 1
else:
timestep_offset = 0

timesteps = np.linspace(
0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

if use_flow_match:
sigmas = timesteps / num_train_timesteps
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
else:
sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)

# setable values
self.num_inference_steps = None

# TODO: Support the full EDM scalings for all prediction types and timestep types
if timestep_type == "continuous" and prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
elif use_flow_match:
self.timesteps = sigmas * num_train_timesteps
else:
self.timesteps = timesteps

self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if not use_flow_match:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])

self.sigmas = sigmas

self.is_scale_input_called = False
self.use_karras_sigmas = use_karras_sigmas
Expand All @@ -257,6 +284,8 @@ def __init__(
self._step_index = None
self._begin_index = None
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()

@property
def init_noise_sigma(self):
Expand Down Expand Up @@ -322,6 +351,7 @@ def set_timesteps(
device: Union[str, torch.device] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
mu: Optional[float] = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Expand Down Expand Up @@ -362,57 +392,81 @@ def set_timesteps(
raise ValueError(
"Cannot set `timesteps` with `config.timestep_type = 'continuous'` and `config.prediction_type = 'v_prediction'`."
)
if timesteps is not None and self.config.use_flow_match:
# TODO: `timesteps / self.config.num_train_timesteps` to get sigmas?
raise ValueError("Cannot set `timesteps` with `config.use_flow_match = True`.")

if self.config.use_dynamic_shifting and mu is None:
raise ValueError(" you have a pass a value for `mu` when `use_dynamic_shifting` is set to be `True`")

if num_inference_steps is None:
num_inference_steps = len(timesteps) if timesteps is not None else len(sigmas) - 1
self.num_inference_steps = num_inference_steps

if sigmas is not None:
if sigmas is not None and not self.config.use_flow_match:
log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
sigmas = np.array(sigmas).astype(np.float32)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])

else:
elif sigmas is None:
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.float32)
else:
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
if self.config.use_flow_match:
timesteps = np.linspace(
0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
)
timesteps -= 1
sigmas = timesteps / self.config.num_train_timesteps
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(
0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=np.float32
)[::-1].copy()
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (
(np.arange(self.config.num_train_timesteps, 0, -step_ratio))
.round()
.copy()
.astype(np.float32)
)
timesteps -= 1
else:
raise ValueError(
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
)
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = (
torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1)
.exp()
.numpy()
)
else:
raise ValueError(
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
" 'linear' or 'log_linear'"
)

if self.config.use_flow_match:
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
log_sigmas = np.log(sigmas)
if self.config.interpolation_type == "linear":
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
elif self.config.interpolation_type == "log_linear":
sigmas = torch.linspace(np.log(sigmas[-1]), np.log(sigmas[0]), num_inference_steps + 1).exp().numpy()
else:
raise ValueError(
f"{self.config.interpolation_type} is not implemented. Please specify interpolation_type to either"
" 'linear' or 'log_linear'"
)

if self.config.use_karras_sigmas:
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=self.num_inference_steps)
Expand All @@ -426,10 +480,16 @@ def set_timesteps(
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas])

if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps

if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
elif self.config.final_sigmas_type == "zero":
sigma_last = 0
elif self.config.invert_sigmas:
sigma_last = 1
else:
raise ValueError(
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
Expand All @@ -442,14 +502,21 @@ def set_timesteps(
# TODO: Support the full EDM scalings for all prediction types and timestep types
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
elif self.config.use_flow_match:
self.timesteps = sigmas * self.config.num_train_timesteps
else:
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

self._step_index = None
self._begin_index = None
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication

def _sigma_to_t(self, sigma, log_sigmas):
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

def _sigma_to_t(self, sigma, log_sigmas=None):
if self.config.use_flow_match:
return sigma * self.config.num_train_timesteps
# get log sigma
log_sigma = np.log(np.maximum(sigma, 1e-10))

Expand Down Expand Up @@ -622,7 +689,7 @@ def step(
),
)

if not self.is_scale_input_called:
if not self.is_scale_input_called and not self.config.use_flow_match:
logger.warning(
"The `scale_model_input` function should be called before `step` to ensure correct denoising. "
"See `StableDiffusionPipeline` for a usage example."
Expand Down Expand Up @@ -663,7 +730,10 @@ def step(
)

# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma_hat
if self.config.use_flow_match:
derivative = model_output
else:
derivative = (sample - pred_original_sample) / sigma_hat

dt = self.sigmas[self.step_index + 1] - sigma_hat

Expand Down Expand Up @@ -713,7 +783,10 @@ def add_noise(
while len(sigma.shape) < len(original_samples.shape):
sigma = sigma.unsqueeze(-1)

noisy_samples = original_samples + noise * sigma
if self.config.use_flow_match:
noisy_samples = (1.0 - sigma) * original_samples + noise * sigma
else:
noisy_samples = original_samples + noise * sigma
return noisy_samples

def get_velocity(self, sample: torch.Tensor, noise: torch.Tensor, timesteps: torch.Tensor) -> torch.Tensor:
Expand Down
Loading

0 comments on commit 4ede43b

Please sign in to comment.