VE/VP SDE updates (open-mmlab#90)
* improve comments for sde_ve scheduler, init tests

* more comments, tweaking pipelines

* timesteps --> num_training_timesteps, some comments

* merge cpu test, add m1 data

* fix scheduler tests with num_train_timesteps

* make np compatible, add tests for sde ve

* minor default variable fixes

* make style and fix-copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Nathan Lambert and patrickvonplaten committed Jul 18, 2022
1 parent ba3c9a9 commit 63c68d9
Showing 10 changed files with 330 additions and 81 deletions.
@@ -18,9 +18,6 @@ def __call__(self, num_inference_steps=2000, generator=None):

         model = self.model.to(device)

-        # TODO(Patrick) move to scheduler config
-        n_steps = 1
-
         x = torch.randn(*shape) * self.scheduler.config.sigma_max
         x = x.to(device)

@@ -30,7 +27,7 @@ def __call__(self, num_inference_steps=2000, generator=None):
         for i, t in enumerate(self.scheduler.timesteps):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)

-            for _ in range(n_steps):
+            for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
                     result = self.model(x, sigma_t)

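For orientation, the two hunks above are the predictor-corrector sampling loop of the VE pipeline. Below is a minimal sketch of that loop built only from the scheduler API visible in this commit; model and scheduler are stand-ins for the pipeline's attributes, so read it as an illustration rather than the pipeline's exact code.

import torch


def sample(model, scheduler, shape, num_inference_steps=2000, device="cpu"):
    # Predictor-corrector sampling in the style of the VE pipeline above.
    scheduler.set_timesteps(num_inference_steps)
    scheduler.set_sigmas(num_inference_steps)

    # start from pure noise scaled to the largest noise level
    x = torch.randn(*shape, device=device) * scheduler.config.sigma_max

    for i, t in enumerate(scheduler.timesteps):
        sigma_t = scheduler.sigmas[i] * torch.ones(shape[0], device=device)

        # corrector: Langevin-style updates; the count now comes from the scheduler config
        for _ in range(scheduler.correct_steps):
            with torch.no_grad():
                score = model(x, sigma_t)
            x = scheduler.step_correct(score, x)

        # predictor: one reverse-SDE step
        with torch.no_grad():
            score = model(x, sigma_t)
        x, x_mean = scheduler.step_pred(score, x, t)

    return x_mean

The only behavioral change in this file is that the number of corrector passes is read from scheduler.correct_steps instead of the hard-coded n_steps = 1.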
@@ -27,6 +27,7 @@ def __call__(self, num_inference_steps=1000, generator=None):
            t = t * torch.ones(shape[0], device=device)
            scaled_t = t * (num_inference_steps - 1)

+            # TODO add corrector
            with torch.no_grad():
                result = model(x, scaled_t)

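The VP pipeline above evaluates the model at a scaled version of the continuous timestep rather than at t itself. A tiny illustration of that mapping, using the default of 1000 inference steps shown in the hunk:

import torch

num_inference_steps = 1000
t = torch.tensor(0.5)                     # continuous time in (0, 1]
scaled_t = t * (num_inference_steps - 1)  # 499.5, the value actually passed to the model
print(float(scaled_t))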
src/diffusers/schedulers/scheduling_ddim.py (19 changes: 8 additions & 11 deletions)
@@ -51,7 +51,7 @@ def alpha_bar(time_step):
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -72,13 +72,13 @@ def __init__(
         )

         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

@@ -88,10 +88,7 @@ def __init__(

         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
-
-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -131,7 +128,7 @@ def step(
         # - pred_prev_sample -> "x_t-1"

         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps

         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -183,4 +180,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
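With the rename, the stride between successive inference timesteps in DDIMScheduler.step is derived from config.num_train_timesteps; a small worked example of that arithmetic:

# Spacing used by the prev_timestep lookup in DDIMScheduler.step above.
num_train_timesteps = 1000   # length of the training schedule (renamed config value)
num_inference_steps = 50     # chosen at sampling time

stride = num_train_timesteps // num_inference_steps  # 20
timestep = 980
prev_timestep = timestep - stride                    # 960
print(stride, prev_timestep)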
src/diffusers/schedulers/scheduling_ddpm.py (10 changes: 5 additions & 5 deletions)
@@ -50,7 +50,7 @@ def alpha_bar(time_step):
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -75,10 +75,10 @@ def __init__(
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
         elif beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

@@ -160,4 +160,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
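Both DDPMScheduler and DDIMScheduler build their betas array from num_train_timesteps in the branches shown above. The sketch below reproduces the linear branch as written and fills in a cosine branch; the cosine helper is an assumption about what betas_for_alpha_bar computes (the Glide/improved-DDPM formula), not a copy of the library function.

import math

import numpy as np


def make_betas(num_train_timesteps=1000, beta_start=0.0001, beta_end=0.02,
               beta_schedule="linear", max_beta=0.999):
    if beta_schedule == "linear":
        # same expression as in the diff above
        return np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
    if beta_schedule == "squaredcos_cap_v2":
        # assumed cosine schedule: beta_t = 1 - alpha_bar(t2) / alpha_bar(t1), clipped at max_beta
        def alpha_bar(t):
            return math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2

        betas = []
        for i in range(num_train_timesteps):
            t1 = i / num_train_timesteps
            t2 = (i + 1) / num_train_timesteps
            betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
        return np.array(betas, dtype=np.float32)
    raise NotImplementedError(beta_schedule)


betas = make_betas()
alphas_cumprod = np.cumprod(1.0 - betas)
print(betas.shape, float(alphas_cumprod[-1]))  # (1000,) and a cumulative alpha close to zero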
src/diffusers/schedulers/scheduling_pndm.py (28 changes: 20 additions & 8 deletions)
@@ -50,25 +50,25 @@ def alpha_bar(time_step):
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
         tensor_format="np",
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
         )

         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

@@ -96,10 +96,12 @@ def get_prk_time_steps(self, num_inference_steps):
         if num_inference_steps in self.prk_time_steps:
             return self.prk_time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )

         prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+            np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
         self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))

@@ -109,7 +111,9 @@ def get_time_steps(self, num_inference_steps):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]

-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )
         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))

         return self.time_steps[num_inference_steps]
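To make the index bookkeeping above concrete, here is what those lists evaluate to for num_train_timesteps=1000 and num_inference_steps=50; pndm_order = 4 is an assumption made here for illustration, matching the fourth-order Runge-Kutta warm-up.

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 50
pndm_order = 4  # assumed value for illustration

inference_step_times = list(
    range(0, num_train_timesteps, num_train_timesteps // num_inference_steps)
)
print(inference_step_times[:3], inference_step_times[-3:])  # [0, 20, 40] [940, 960, 980]

# plain (linear multi-step) schedule: the last three anchors are left to the PRK warm-up
time_steps = list(reversed(inference_step_times[:-3]))
print(time_steps[:4])  # [920, 900, 880, 860]

# PRK warm-up schedule: half-strides interleaved among the last pndm_order anchors
prk_time_steps = np.array(inference_step_times[-pndm_order:]).repeat(2) + np.tile(
    np.array([0, num_train_timesteps // num_inference_steps // 2]), pndm_order
)
prk_time_steps = [int(v) for v in reversed(prk_time_steps[:-1].repeat(2)[1:-1])]
print(prk_time_steps[:4])  # [980, 970, 970, 960]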
Expand All @@ -135,6 +139,10 @@ def step_prk(
sample: Union[torch.FloatTensor, np.ndarray],
num_inference_steps,
):
"""
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
solution to the differential equation.
"""
t = timestep
prk_time_steps = self.get_prk_time_steps(num_inference_steps)

@@ -165,6 +173,10 @@ def step_plms(
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+        times to approximate the solution.
+        """
         t = timestep
         if len(self.ets) < 3:
             raise ValueError(
@@ -221,4 +233,4 @@ def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
         return prev_sample

     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
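The new docstrings distinguish the two stepping modes: step_prk spends four model evaluations per update, while step_plms combines one new evaluation with cached previous ones. Below is a generic sketch of that linear multi-step idea, using the classic fourth-order Adams-Bashforth weights that PLMS-style samplers build on; it is an illustration, not the scheduler's exact code.

from collections import deque

import numpy as np

# Cache of the most recent model outputs, newest last (the scheduler keeps a
# similar history in self.ets).
ets = deque(maxlen=4)


def lms_combine(new_output):
    ets.append(new_output)
    if len(ets) < 4:
        # not enough history yet; a real sampler warms up with Runge-Kutta steps
        return new_output
    return (1.0 / 24) * (55 * ets[-1] - 59 * ets[-2] + 37 * ets[-3] - 9 * ets[-4])


# toy usage with scalar stand-ins for model outputs
for k, out in enumerate(np.linspace(1.0, 0.0, 8)):
    print(k, round(float(lms_combine(out)), 4))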
src/diffusers/schedulers/scheduling_sde_ve.py (136 changes: 104 additions & 32 deletions)
@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch

 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
+import pdb

 import numpy as np
 import torch
@@ -24,61 +25,132 @@


 class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
-    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
+    """
+    The variance exploding stochastic differential equation (SDE) scheduler.
+    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
+    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    distribution of the data.
+    :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
+    epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
+    "np" or "pt" for the expected format of samples passed to the Scheduler.
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps=2000,
+        snr=0.15,
+        sigma_min=0.01,
+        sigma_max=1348,
+        sampling_eps=1e-5,
+        correct_steps=1,
+        tensor_format="pt",
+    ):
         super().__init__()
         self.register_to_config(
+            num_train_timesteps=num_train_timesteps,
             snr=snr,
             sigma_min=sigma_min,
             sigma_max=sigma_max,
             sampling_eps=sampling_eps,
+            correct_steps=correct_steps,
         )

         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None

         # TODO - update step to be torch-independant
         self.set_format(tensor_format=tensor_format)

     def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
+        elif tensor_format == "pt":
+            self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")

     def set_sigmas(self, num_inference_steps):
         if self.timesteps is None:
             self.set_timesteps(num_inference_steps)

-        self.discrete_sigmas = torch.exp(
-            torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
-        )
-        self.sigmas = torch.tensor(
-            [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
-        )
-
-    def step_pred(self, result, x, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.discrete_sigmas = np.exp(
+                np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = np.array(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        elif tensor_format == "pt":
+            self.discrete_sigmas = torch.exp(
+                torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = torch.tensor(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def get_adjacent_sigma(self, timesteps, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+        elif tensor_format == "pt":
+            return torch.where(
+                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+            )
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(self, score, x, t):
+        """
+        Predict the sample at the previous timestep by reversing the SDE.
+        """
         # TODO(Patrick) better comments + non-PyTorch
-        t = t * torch.ones(x.shape[0], device=x.device)
-        timestep = (t * (len(self.timesteps) - 1)).long()
-
-        sigma = self.discrete_sigmas.to(t.device)[timestep]
-        adjacent_sigma = torch.where(
-            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
-        )
-        f = torch.zeros_like(x)
-        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
+        t = self.repeat_scalar(t, x.shape[0])
+        timesteps = self.long((t * (len(self.timesteps) - 1)))
+
+        sigma = self.discrete_sigmas[timesteps]
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
+        drift = self.zeros_like(x)
+        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # also equation 47 shows the analog from SDE models to ancestral sampling methods
+        drift = drift - diffusion[:, None, None, None] ** 2 * score
+
+        # equation 6: sample noise for the diffusion term of
+        noise = self.randn_like(x)
+        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        # TODO is the variable diffusion the correct scaling term for the noise?
+        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        return x, x_mean

-        f = f - G[:, None, None, None] ** 2 * result
+    def step_correct(self, score, x):
+        """
+        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
+        making the prediction for the previous timestep.
+        """
+        # TODO(Patrick) non-PyTorch

-        z = torch.randn_like(x)
-        x_mean = x - f
-        x = x_mean + G[:, None, None, None] * z
-        return x, x_mean
+        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+        # sample noise for correction
+        noise = self.randn_like(x)

-    def step_correct(self, result, x):
-        # TODO(Patrick) better comments + non-PyTorch
-        noise = torch.randn_like(x)
-        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
-        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+        # compute step size from the score, the noise, and the snr
+        grad_norm = self.norm(score)
+        noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = step_size * torch.ones(x.shape[0], device=x.device)
-        x_mean = x + step_size[:, None, None, None] * result
+        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)

-        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
+        # compute corrected sample: score term and noise term
+        x_mean = x + step_size[:, None, None, None] * score
+        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise

         return x

+    def __len__(self):
+        return self.config.num_train_timesteps
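Pulling the scheduler pieces together, here is a small NumPy sketch of the variance-exploding predictor and corrector updates that step_pred and step_correct implement; the equation references in the comments follow the score-SDE paper pointed to by the file's disclaimer. The score function below is a stand-in (the exact score of a zero-mean Gaussian) and shapes are reduced to a flat vector, so this illustrates the update rules rather than the scheduler itself.

import numpy as np

rng = np.random.default_rng(0)

# Config values mirroring the scheduler defaults above.
num_inference_steps = 100
sigma_min, sigma_max = 0.01, 1348.0
snr, correct_steps, sampling_eps = 0.15, 1, 1e-5

# Continuous timesteps from 1 down to sampling_eps, plus the two sigma schedules
# the scheduler keeps: a geometric interpolation sigma(t) and a discrete ladder.
timesteps = np.linspace(1.0, sampling_eps, num_inference_steps)
sigmas = sigma_min * (sigma_max / sigma_min) ** timesteps
discrete_sigmas = np.exp(np.linspace(np.log(sigma_min), np.log(sigma_max), num_inference_steps))


def score_fn(x, sigma):
    # Stand-in for the trained score model: the exact score of N(0, sigma^2 I).
    return -x / sigma**2


x = rng.standard_normal(4) * sigma_max
for i, t in enumerate(timesteps):
    sigma = sigmas[i]

    # Corrector: Langevin-style update with an SNR-controlled step size,
    # same formula as step_correct: step_size = 2 * (snr * ||noise|| / ||score||)^2.
    for _ in range(correct_steps):
        score = score_fn(x, sigma)
        noise = rng.standard_normal(x.shape)
        step_size = 2 * (snr * np.linalg.norm(noise) / np.linalg.norm(score)) ** 2
        x = x + step_size * score + np.sqrt(2 * step_size) * noise

    # Predictor: reverse-SDE step between adjacent discrete sigmas, as in step_pred.
    idx = int(round(t * (num_inference_steps - 1)))
    sigma_t = discrete_sigmas[idx]
    adjacent_sigma = 0.0 if idx == 0 else discrete_sigmas[idx - 1]
    diffusion = np.sqrt(sigma_t**2 - adjacent_sigma**2)
    drift = -(diffusion**2) * score_fn(x, sigma_t)
    x_mean = x - drift                      # dt is a small negative step, hence the subtraction
    x = x_mean + diffusion * rng.standard_normal(x.shape)

print(np.round(x_mean, 4))  # collapses to the toy score's mean (zero) at the final step

Under this toy score the final denoised mean is exactly zero, which is a quick sanity check that the drift and diffusion terms are wired the same way as in the scheduler code above.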