Commit 63c68d9

Nathan Lambert and Patrick von Platen authored
VE/VP SDE updates (open-mmlab#90)
* improve comments for sde_ve scheduler, init tests
* more comments, tweaking pipelines
* timesteps --> num_train_timesteps, some comments
* merge cpu test, add m1 data
* fix scheduler tests with num_train_timesteps
* make np compatible, add tests for sde ve
* minor default variable fixes
* make style and fix-copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
1 parent ba3c9a9 commit 63c68d9

10 files changed: +330 −81 lines
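The central change is a config rename: every scheduler now takes `num_train_timesteps` instead of `timesteps`. A minimal before/after sketch of the call-site impact (the scheduler class here is just one of the renamed ones):

```python
from diffusers import DDPMScheduler

# before this commit:
#   scheduler = DDPMScheduler(timesteps=1000)

# after this commit, the constructor argument and the registered config key are renamed:
scheduler = DDPMScheduler(num_train_timesteps=1000)

# __len__ and the step functions now read the value back under the new name
assert len(scheduler) == scheduler.config.num_train_timesteps == 1000
```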

src/diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py

+1 −4

@@ -18,9 +18,6 @@ def __call__(self, num_inference_steps=2000, generator=None):
 
         model = self.model.to(device)
 
-        # TODO(Patrick) move to scheduler config
-        n_steps = 1
-
         x = torch.randn(*shape) * self.scheduler.config.sigma_max
         x = x.to(device)
 
@@ -30,7 +27,7 @@ def __call__(self, num_inference_steps=2000, generator=None):
         for i, t in enumerate(self.scheduler.timesteps):
             sigma_t = self.scheduler.sigmas[i] * torch.ones(shape[0], device=device)
 
-            for _ in range(n_steps):
+            for _ in range(self.scheduler.correct_steps):
                 with torch.no_grad():
                     result = self.model(x, sigma_t)
 
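The hard-coded `n_steps = 1` moves into the scheduler config as `correct_steps`, so the corrector depth is a property of the scheduler rather than the pipeline. A short sketch of configuring it (the value 3 is illustrative, not from this commit):

```python
from diffusers import ScoreSdeVeScheduler

# correct_steps defaults to 1, matching the old hard-coded n_steps
scheduler = ScoreSdeVeScheduler(correct_steps=3)

# the pipeline loop above reads the value back from the registered config
assert scheduler.config.correct_steps == 3
```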

src/diffusers/pipelines/score_sde_vp/pipeline_score_sde_vp.py

+1 −0

@@ -27,6 +27,7 @@ def __call__(self, num_inference_steps=1000, generator=None):
             t = t * torch.ones(shape[0], device=device)
             scaled_t = t * (num_inference_steps - 1)
 
+            # TODO add corrector
             with torch.no_grad():
                 result = model(x, scaled_t)
 

src/diffusers/schedulers/scheduling_ddim.py

+8 −11

@@ -51,7 +51,7 @@ def alpha_bar(time_step):
 class DDIMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -72,13 +72,13 @@ def __init__(
         )
 
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=np.float32) ** 2
+            self.betas = np.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=np.float32) ** 2
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -88,10 +88,7 @@ def __init__(
 
         # setable values
         self.num_inference_steps = None
-        self.timesteps = np.arange(0, self.config.timesteps)[::-1].copy()
-
-        self.tensor_format = tensor_format
-        self.set_format(tensor_format=tensor_format)
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
 
     def _get_variance(self, timestep, prev_timestep):
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -131,7 +128,7 @@ def step(
         # - pred_prev_sample -> "x_t-1"
 
         # 1. get previous step value (=t-1)
-        prev_timestep = timestep - self.config.timesteps // self.num_inference_steps
+        prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
 
         # 2. compute alphas, betas
         alpha_prod_t = self.alphas_cumprod[timestep]
@@ -183,4 +180,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_ddpm.py

+5 −5

@@ -50,7 +50,7 @@ def alpha_bar(time_step):
 class DDPMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
@@ -62,7 +62,7 @@ def __init__(
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
             beta_start=beta_start,
             beta_end=beta_end,
             beta_schedule=beta_schedule,
@@ -75,10 +75,10 @@ def __init__(
         if trained_betas is not None:
             self.betas = np.asarray(trained_betas)
         elif beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -160,4 +160,4 @@ def add_noise(self, original_samples, noise, timesteps):
         return noisy_samples
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps

src/diffusers/schedulers/scheduling_pndm.py

+20 −8

@@ -50,25 +50,25 @@ def alpha_bar(time_step):
 class PNDMScheduler(SchedulerMixin, ConfigMixin):
     def __init__(
         self,
-        timesteps=1000,
+        num_train_timesteps=1000,
         beta_start=0.0001,
         beta_end=0.02,
         beta_schedule="linear",
         tensor_format="np",
     ):
         super().__init__()
         self.register_to_config(
-            timesteps=timesteps,
+            num_train_timesteps=num_train_timesteps,
            beta_start=beta_start,
            beta_end=beta_end,
            beta_schedule=beta_schedule,
        )
 
         if beta_schedule == "linear":
-            self.betas = np.linspace(beta_start, beta_end, timesteps, dtype=np.float32)
+            self.betas = np.linspace(beta_start, beta_end, num_train_timesteps, dtype=np.float32)
         elif beta_schedule == "squaredcos_cap_v2":
             # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(timesteps)
+            self.betas = betas_for_alpha_bar(num_train_timesteps)
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
 
@@ -96,10 +96,12 @@ def get_prk_time_steps(self, num_inference_steps):
         if num_inference_steps in self.prk_time_steps:
             return self.prk_time_steps[num_inference_steps]
 
-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )
 
         prk_time_steps = np.array(inference_step_times[-self.pndm_order :]).repeat(2) + np.tile(
-            np.array([0, self.config.timesteps // num_inference_steps // 2]), self.pndm_order
+            np.array([0, self.config.num_train_timesteps // num_inference_steps // 2]), self.pndm_order
         )
         self.prk_time_steps[num_inference_steps] = list(reversed(prk_time_steps[:-1].repeat(2)[1:-1]))
 
@@ -109,7 +111,9 @@ def get_time_steps(self, num_inference_steps):
         if num_inference_steps in self.time_steps:
             return self.time_steps[num_inference_steps]
 
-        inference_step_times = list(range(0, self.config.timesteps, self.config.timesteps // num_inference_steps))
+        inference_step_times = list(
+            range(0, self.config.num_train_timesteps, self.config.num_train_timesteps // num_inference_steps)
+        )
         self.time_steps[num_inference_steps] = list(reversed(inference_step_times[:-3]))
 
         return self.time_steps[num_inference_steps]
@@ -135,6 +139,10 @@ def step_prk(
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
+        solution to the differential equation.
+        """
         t = timestep
         prk_time_steps = self.get_prk_time_steps(num_inference_steps)
 
@@ -165,6 +173,10 @@ def step_plms(
         sample: Union[torch.FloatTensor, np.ndarray],
         num_inference_steps,
     ):
+        """
+        Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
+        times to approximate the solution.
+        """
         t = timestep
         if len(self.ets) < 3:
             raise ValueError(
@@ -221,4 +233,4 @@ def get_prev_sample(self, sample, t_orig, t_orig_prev, model_output):
         return prev_sample
 
     def __len__(self):
-        return self.config.timesteps
+        return self.config.num_train_timesteps
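Per the new docstrings, PRK steps are the expensive warm-up (four model evaluations each) that builds the multi-step history, after which PLMS advances with a single evaluation per step. A minimal sketch of that two-phase loop, assuming a trained noise-prediction `model` and the argument order used by the PNDM pipeline of this era (residual, timestep, sample, num_inference_steps):

```python
import torch
from diffusers import PNDMScheduler

scheduler = PNDMScheduler(num_train_timesteps=1000)
num_inference_steps = 50
sample = torch.randn(1, 3, 32, 32)  # toy sample; a real pipeline uses the model's shape

def model(x, t):
    # stand-in for a trained noise-prediction network
    return torch.zeros_like(x)

# phase 1: Runge-Kutta warm-up populates the scheduler's history of model outputs
for t in scheduler.get_prk_time_steps(num_inference_steps):
    residual = model(sample, t)
    sample = scheduler.step_prk(residual, t, sample, num_inference_steps)

# phase 2: linear multi-step reuses the stored outputs, one evaluation per step
for t in scheduler.get_time_steps(num_inference_steps):
    residual = model(sample, t)
    sample = scheduler.step_plms(residual, t, sample, num_inference_steps)
```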

src/diffusers/schedulers/scheduling_sde_ve.py

+104 −32

@@ -15,6 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
 
 # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
+import pdb
 
 import numpy as np
 import torch
@@ -24,61 +25,132 @@
 
 
 class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin):
-    def __init__(self, snr=0.15, sigma_min=0.01, sigma_max=1348, sampling_eps=1e-5, tensor_format="np"):
+    """
+    The variance exploding stochastic differential equation (SDE) scheduler.
+
+    :param snr: coefficient weighting the step from the score sample (from the network) to the random noise. :param
+    sigma_min: initial noise scale for sigma sequence in sampling procedure. The minimum sigma should mirror the
+    distribution of the data.
+    :param sigma_max: :param sampling_eps: the end value of sampling, where timesteps decrease progessively from 1 to
+    epsilon. :param correct_steps: number of correction steps performed on a produced sample. :param tensor_format:
+    "np" or "pt" for the expected format of samples passed to the Scheduler.
+    """
+
+    def __init__(
+        self,
+        num_train_timesteps=2000,
+        snr=0.15,
+        sigma_min=0.01,
+        sigma_max=1348,
+        sampling_eps=1e-5,
+        correct_steps=1,
+        tensor_format="pt",
+    ):
         super().__init__()
         self.register_to_config(
+            num_train_timesteps=num_train_timesteps,
             snr=snr,
             sigma_min=sigma_min,
             sigma_max=sigma_max,
             sampling_eps=sampling_eps,
+            correct_steps=correct_steps,
         )
 
         self.sigmas = None
         self.discrete_sigmas = None
         self.timesteps = None
 
+        # TODO - update step to be torch-independant
+        self.set_format(tensor_format=tensor_format)
+
     def set_timesteps(self, num_inference_steps):
-        self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.timesteps = np.linspace(1, self.config.sampling_eps, num_inference_steps)
+        elif tensor_format == "pt":
+            self.timesteps = torch.linspace(1, self.config.sampling_eps, num_inference_steps)
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
 
     def set_sigmas(self, num_inference_steps):
         if self.timesteps is None:
             self.set_timesteps(num_inference_steps)
 
-        self.discrete_sigmas = torch.exp(
-            torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
-        )
-        self.sigmas = torch.tensor(
-            [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
-        )
-
-    def step_pred(self, result, x, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            self.discrete_sigmas = np.exp(
+                np.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = np.array(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        elif tensor_format == "pt":
+            self.discrete_sigmas = torch.exp(
+                torch.linspace(np.log(self.config.sigma_min), np.log(self.config.sigma_max), num_inference_steps)
+            )
+            self.sigmas = torch.tensor(
+                [self.config.sigma_min * (self.config.sigma_max / self.sigma_min) ** t for t in self.timesteps]
+            )
+        else:
+            raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def get_adjacent_sigma(self, timesteps, t):
+        tensor_format = getattr(self, "tensor_format", "pt")
+        if tensor_format == "np":
+            return np.where(timesteps == 0, np.zeros_like(t), self.discrete_sigmas[timesteps - 1])
+        elif tensor_format == "pt":
+            return torch.where(
+                timesteps == 0, torch.zeros_like(t), self.discrete_sigmas[timesteps - 1].to(timesteps.device)
+            )
+
+        raise ValueError(f"`self.tensor_format`: {self.tensor_format} is not valid.")
+
+    def step_pred(self, score, x, t):
+        """
+        Predict the sample at the previous timestep by reversing the SDE.
+        """
         # TODO(Patrick) better comments + non-PyTorch
-        t = t * torch.ones(x.shape[0], device=x.device)
-        timestep = (t * (len(self.timesteps) - 1)).long()
-
-        sigma = self.discrete_sigmas.to(t.device)[timestep]
-        adjacent_sigma = torch.where(
-            timestep == 0, torch.zeros_like(t), self.discrete_sigmas[timestep - 1].to(timestep.device)
-        )
-        f = torch.zeros_like(x)
-        G = torch.sqrt(sigma**2 - adjacent_sigma**2)
+        t = self.repeat_scalar(t, x.shape[0])
+        timesteps = self.long((t * (len(self.timesteps) - 1)))
+
+        sigma = self.discrete_sigmas[timesteps]
+        adjacent_sigma = self.get_adjacent_sigma(timesteps, t)
+        drift = self.zeros_like(x)
+        diffusion = (sigma**2 - adjacent_sigma**2) ** 0.5
+
+        # equation 6 in the paper: the score modeled by the network is grad_x log pt(x)
+        # also equation 47 shows the analog from SDE models to ancestral sampling methods
+        drift = drift - diffusion[:, None, None, None] ** 2 * score
+
+        # equation 6: sample noise for the diffusion term of
+        noise = self.randn_like(x)
+        x_mean = x - drift  # subtract because `dt` is a small negative timestep
+        # TODO is the variable diffusion the correct scaling term for the noise?
+        x = x_mean + diffusion[:, None, None, None] * noise  # add impact of diffusion field g
+        return x, x_mean
 
-        f = f - G[:, None, None, None] ** 2 * result
+    def step_correct(self, score, x):
+        """
+        Correct the predicted sample based on the output score of the network. This is often run repeatedly after
+        making the prediction for the previous timestep.
+        """
+        # TODO(Patrick) non-PyTorch
 
-        z = torch.randn_like(x)
-        x_mean = x - f
-        x = x_mean + G[:, None, None, None] * z
-        return x, x_mean
+        # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
+        # sample noise for correction
+        noise = self.randn_like(x)
 
-    def step_correct(self, result, x):
-        # TODO(Patrick) better comments + non-PyTorch
-        noise = torch.randn_like(x)
-        grad_norm = torch.norm(result.reshape(result.shape[0], -1), dim=-1).mean()
-        noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
+        # compute step size from the score, the noise, and the snr
+        grad_norm = self.norm(score)
+        noise_norm = self.norm(noise)
         step_size = (self.config.snr * noise_norm / grad_norm) ** 2 * 2
-        step_size = step_size * torch.ones(x.shape[0], device=x.device)
-        x_mean = x + step_size[:, None, None, None] * result
+        step_size = self.repeat_scalar(step_size, x.shape[0])  # * self.ones(x.shape[0], device=x.device)
 
-        x = x_mean + torch.sqrt(step_size * 2)[:, None, None, None] * noise
+        # compute corrected sample: score term and noise term
+        x_mean = x + step_size[:, None, None, None] * score
+        x = x_mean + ((step_size * 2) ** 0.5)[:, None, None, None] * noise
 
         return x
+
+    def __len__(self):
+        return self.config.num_train_timesteps
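Together, `step_pred` and `step_correct` implement the predictor-corrector sampler of Song et al. (2021). A minimal sketch of the full loop under the new API, assuming the `"pt"` tensor format and a stand-in score `model` (shapes are illustrative):

```python
import torch
from diffusers import ScoreSdeVeScheduler

scheduler = ScoreSdeVeScheduler(tensor_format="pt")
num_inference_steps = 2000
scheduler.set_timesteps(num_inference_steps)
scheduler.set_sigmas(num_inference_steps)

def model(x, sigma_t):
    # stand-in for a trained score network estimating grad_x log p_t(x)
    return torch.zeros_like(x)

# start from pure noise at the largest noise scale
x = torch.randn(1, 3, 32, 32) * scheduler.config.sigma_max

for i, t in enumerate(scheduler.timesteps):
    sigma_t = scheduler.sigmas[i] * torch.ones(x.shape[0])

    # corrector: Langevin-style updates nudge x toward higher density
    for _ in range(scheduler.config.correct_steps):
        with torch.no_grad():
            score = model(x, sigma_t)
        x = scheduler.step_correct(score, x)

    # predictor: one reverse-SDE step toward the previous timestep
    with torch.no_grad():
        score = model(x, sigma_t)
    x, x_mean = scheduler.step_pred(score, x, t)

# the noise-free mean x_mean is typically taken as the final sample
```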
