Skip to content

Commit

Permalink
timesteps, set_timesteps(sigmas=..)
Browse files Browse the repository at this point in the history
  • Loading branch information
hlky committed Nov 23, 2024
1 parent 98a52db commit 4e9119f
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,14 @@ def __init__(
0 + timestep_offset, num_train_timesteps - 1 + timestep_offset, num_train_timesteps, dtype=float
)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
print(timesteps)

if use_flow_match:
sigmas = timesteps / num_train_timesteps
if not use_dynamic_shifting:
# when use_dynamic_shifting is True, we apply the timestep shifting on the fly based on the image resolution
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)
print(sigmas)
else:
sigmas = (((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5).flip(0)

Expand All @@ -268,6 +270,7 @@ def __init__(
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas])
elif use_flow_match:
self.timesteps = sigmas * num_train_timesteps
print(self.timesteps)
else:
self.timesteps = timesteps

Expand Down Expand Up @@ -407,6 +410,19 @@ def set_timesteps(
log_sigmas = np.log(np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5))
sigmas = np.array(sigmas).astype(np.float32)
timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas[:-1]])
elif sigmas is not None and self.config.use_flow_match:
sigmas = np.array(sigmas).astype(np.float32)
timesteps = sigmas * self.config.num_train_timesteps

if self.config.use_flow_match:
if self.config.use_dynamic_shifting:
sigmas = self.time_shift(mu, 1.0, sigmas)
else:
sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)

if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
elif sigmas is None:
if timesteps is not None:
timesteps = np.array(timesteps).astype(np.float32)
Expand Down Expand Up @@ -482,7 +498,7 @@ def set_timesteps(

if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
timesteps = sigmas * self.config.num_train_timesteps

if self.config.final_sigmas_type == "sigma_min":
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
Expand All @@ -502,8 +518,6 @@ def set_timesteps(
# TODO: Support the full EDM scalings for all prediction types and timestep types
if self.config.timestep_type == "continuous" and self.config.prediction_type == "v_prediction":
self.timesteps = torch.Tensor([0.25 * sigma.log() for sigma in sigmas[:-1]]).to(device=device)
elif self.config.use_flow_match:
self.timesteps = sigmas[:-1] * self.config.num_train_timesteps
else:
self.timesteps = torch.from_numpy(timesteps.astype(np.float32)).to(device=device)

Expand Down

0 comments on commit 4e9119f

Please sign in to comment.