import torch
import numpy as np
#----------------------------------------------------------------------------
def get_schedule(num_steps, sigma_min, sigma_max, device=None, schedule_type='polynomial', schedule_rho=7, net=None, dp_list=None):
"""
Get the time schedule for sampling.
Args:
num_steps: A `int`. The total number of the time steps with `num_steps-1` spacings.
sigma_min: A `float`. The ending sigma during samping.
sigma_max: A `float`. The starting sigma during sampling.
device: A torch device.
schedule_type: A `str`. The type of time schedule. We support three types:
- 'polynomial': polynomial time schedule. (Recommended in EDM.)
- 'logsnr': uniform logSNR time schedule. (Recommended in DPM-Solver for small-resolution datasets.)
- 'time_uniform': uniform time schedule. (Recommended in DPM-Solver for high-resolution datasets.)
- 'discrete': time schedule used in LDM. (Recommended when using pre-trained diffusion models from the LDM and Stable Diffusion codebases.)
schedule_type: A `float`. Time step exponent.
net: A pre-trained diffusion model. Required when schedule_type == 'discrete'.
Returns:
a PyTorch tensor with shape [num_steps].
"""
if schedule_type == 'polynomial':
step_indices = torch.arange(num_steps, device=device)
t_steps = (sigma_max ** (1 / schedule_rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / schedule_rho) - sigma_max ** (1 / schedule_rho))) ** schedule_rho
elif schedule_type == 'logsnr':
logsnr_max = -1 * torch.log(torch.tensor(sigma_min))
logsnr_min = -1 * torch.log(torch.tensor(sigma_max))
t_steps = torch.linspace(logsnr_min.item(), logsnr_max.item(), steps=num_steps, device=device)
t_steps = (-t_steps).exp()
    elif schedule_type == 'time_uniform':
        epsilon_s = 1e-3
        # sigma(t) of the VP-SDE, used to map the uniform t steps back to sigmas.
        vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
        step_indices = torch.arange(num_steps, device=device)
        # VP schedule parameters matched to the given sigma range (as in EDM).
        vp_beta_d = 2 * (torch.log(torch.tensor(sigma_min) ** 2 + 1) / epsilon_s - torch.log(torch.tensor(sigma_max) ** 2 + 1)) / (epsilon_s - 1)
        vp_beta_min = torch.log(torch.tensor(sigma_max) ** 2 + 1) - 0.5 * vp_beta_d
        t_steps_temp = (1 + step_indices / (num_steps - 1) * (epsilon_s ** (1 / schedule_rho) - 1)) ** schedule_rho
        t_steps = vp_sigma(vp_beta_d, vp_beta_min)(t_steps_temp.cpu())
elif schedule_type == 'discrete':
assert net is not None
t_steps_min = net.sigma_inv(torch.tensor(sigma_min, device=device))
t_steps_max = net.sigma_inv(torch.tensor(sigma_max, device=device))
step_indices = torch.arange(num_steps, device=device)
t_steps_temp = (t_steps_max + step_indices / (num_steps - 1) * (t_steps_min ** (1 / schedule_rho) - t_steps_max)) ** schedule_rho
t_steps = net.sigma(t_steps_temp)
else:
        raise ValueError("Unknown schedule type: {}".format(schedule_type))
if dp_list is not None:
return t_steps[dp_list].to(device)
return t_steps.to(device)
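#----------------------------------------------------------------------------
# Usage sketch (illustrative, not from the original file): building a 10-step
# EDM-style polynomial schedule. The sigma range below is the EDM default and
# is only an assumption here.
#
# >>> t_steps = get_schedule(10, 0.002, 80.0, device=torch.device('cpu'),
# ...                        schedule_type='polynomial', schedule_rho=7)
# >>> t_steps.shape
# torch.Size([10])
# >>> t_steps[0].item(), t_steps[-1].item()   # approximately (80.0, 0.002)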
# Copied from the DPM-Solver codebase (https://github.com/LuChengTHU/dpm-solver).
# Different from the original codebase, we use the VE-SDE formulation for simplicity
# while the official implementation uses the equivalent VP-SDE formulation.
##############################
### Utils for DPM-Solver++ ###
##############################
#----------------------------------------------------------------------------
def expand_dims(v, dims):
"""
Expand the tensor `v` to the dim `dims`.
Args:
v: a PyTorch tensor with shape [N].
dim: a `int`.
Returns:
a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
"""
return v[(...,) + (None,)*(dims - 1)]
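# Example: broadcasting a per-sample scalar over an image batch.
# >>> v = torch.ones(4)
# >>> expand_dims(v, 4).shape
# torch.Size([4, 1, 1, 1])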
#----------------------------------------------------------------------------
def dynamic_thresholding_fn(x0):
"""
The dynamic thresholding method
"""
dims = x0.dim()
p = 0.995
s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
s = expand_dims(torch.maximum(s, 1. * torch.ones_like(s).to(s.device)), dims)
x0 = torch.clamp(x0, -s, s) / s
return x0
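# Example: each sample is clamped to its 99.5th-percentile absolute value
# (floored at 1) and rescaled, so the output always lies in [-1, 1].
# >>> x0 = 3 * torch.randn(8, 3, 32, 32)
# >>> bool((dynamic_thresholding_fn(x0).abs() <= 1).all())
# True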
#----------------------------------------------------------------------------
def dpm_pp_update(x, model_prev_list, t_prev_list, t, order, predict_x0=True):
    """Dispatch a multistep DPM-Solver++ update of the given order (1, 2, or 3)."""
    if order == 1:
return dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1], predict_x0=predict_x0)
elif order == 2:
return multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, predict_x0=predict_x0)
elif order == 3:
return multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, predict_x0=predict_x0)
else:
raise ValueError("Solver order must be 1 or 2 or 3, got {}".format(order))
#----------------------------------------------------------------------------
def dpm_solver_first_update(x, s, t, model_s=None, predict_x0=True):
s, t = s.reshape(-1, 1, 1, 1), t.reshape(-1, 1, 1, 1)
lambda_s, lambda_t = -1 * s.log(), -1 * t.log()
h = lambda_t - lambda_s
phi_1 = torch.expm1(-h) if predict_x0 else torch.expm1(h)
# VE-SDE formulation
if predict_x0:
x_t = (t / s) * x - phi_1 * model_s
else:
x_t = x - t * phi_1 * model_s
return x_t
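# Note: since lambda = -log(sigma), exp(-h) = t / s, so in the data-prediction
# branch phi_1 = t / s - 1 and the update above is the familiar DDIM-like
# interpolation x_t = (t / s) * x + (1 - t / s) * model_s.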
#----------------------------------------------------------------------------
def multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, predict_x0=True):
t = t.reshape(-1, 1, 1, 1)
model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
t_prev_1, t_prev_0 = t_prev_list[-2].reshape(-1, 1, 1, 1), t_prev_list[-1].reshape(-1, 1, 1, 1)
lambda_prev_1, lambda_prev_0, lambda_t = -1 * t_prev_1.log(), -1 * t_prev_0.log(), -1 * t.log()
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0 = h_0 / h
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
phi_1 = torch.expm1(-h) if predict_x0 else torch.expm1(h)
# VE-SDE formulation
if predict_x0:
x_t = (t / t_prev_0) * x - phi_1 * model_prev_0 - 0.5 * phi_1 * D1_0
else:
x_t = x - t * phi_1 * model_prev_0 - 0.5 * t * phi_1 * D1_0
return x_t
#----------------------------------------------------------------------------
def multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, predict_x0=True):
t = t.reshape(-1, 1, 1, 1)
model_prev_2, model_prev_1, model_prev_0 = model_prev_list[-3], model_prev_list[-2], model_prev_list[-1]
t_prev_2, t_prev_1, t_prev_0 = t_prev_list[-3], t_prev_list[-2], t_prev_list[-1]
t_prev_2, t_prev_1, t_prev_0 = t_prev_2.reshape(-1, 1, 1, 1), t_prev_1.reshape(-1, 1, 1, 1), t_prev_0.reshape(-1, 1, 1, 1)
lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = -1 * t_prev_2.log(), -1 * t_prev_1.log(), -1 * t_prev_0.log(), -1 * t.log()
h_1 = lambda_prev_1 - lambda_prev_2
h_0 = lambda_prev_0 - lambda_prev_1
h = lambda_t - lambda_prev_0
r0, r1 = h_0 / h, h_1 / h
D1_0 = (1. / r0) * (model_prev_0 - model_prev_1)
D1_1 = (1. / r1) * (model_prev_1 - model_prev_2)
D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
D2 = (1. / (r0 + r1)) * (D1_0 - D1_1)
phi_1 = torch.expm1(-h) if predict_x0 else torch.expm1(h)
phi_2 = phi_1 / h + 1. if predict_x0 else phi_1 / h - 1.
phi_3 = phi_2 / h - 0.5
# VE-SDE formulation
if predict_x0:
x_t = (t / t_prev_0) * x - phi_1 * model_prev_0 + phi_2 * D1 - phi_3 * D2
else:
x_t = x - t * phi_1 * model_prev_0 - t * phi_2 * D1 - t * phi_3 * D2
return x_t
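# Note: D1_0, D1_1, D1, and D2 are divided-difference estimates of the first
# and second derivatives of the model output with respect to lambda; using
# them is what lifts the multistep updates above to second and third order.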
# Copied from the UniPC codebase (https://github.com/wl-zhao/UniPC).
# Different from the original codebase, we use the VE-SDE formulation for simplicity
# while the official implementation uses the equivalent VP-SDE formulation.
##############################
### Utils for UniPC solver ###
##############################
#----------------------------------------------------------------------------
def unipc_update(
    x, model_prev_list, t_prev_list, t, order, x_t=None, variant='bh1', predict_x0=True,
    net=None, class_labels=None, use_corrector=True,
):
    """
    One UniPC step: a multistep predictor (UniP), optionally followed by a
    corrector (UniC) that reuses a fresh model evaluation at time `t`.
    """
    assert order <= len(model_prev_list)
# first compute rks
t_prev_0 = t_prev_list[-1].reshape(1,)
t = t.reshape(1,)
lambda_prev_0 = -1 * t_prev_0.log()
lambda_t = -1 * t.log()
model_prev_0 = model_prev_list[-1]
h = lambda_t - lambda_prev_0
rks = []
D1s = []
for i in range(1, order):
t_prev_i = t_prev_list[-(i + 1)].reshape(1,)
model_prev_i = model_prev_list[-(i + 1)]
lambda_prev_i = -1 * t_prev_i.log()
rk = (lambda_prev_i - lambda_prev_0) / h
rks.append(rk)
D1s.append((model_prev_i - model_prev_0) / rk)
rks.append(1.)
rks = torch.tensor(rks, device=x.device)
R = []
b = []
hh = -h if predict_x0 else h
h_phi_1 = torch.expm1(hh)
h_phi_k = h_phi_1 / hh - 1
factorial_i = 1
if variant == 'bh1':
B_h = hh
elif variant == 'bh2':
B_h = torch.expm1(hh)
else:
raise NotImplementedError()
for i in range(1, order + 1):
R.append(torch.pow(rks, i - 1))
b.append(h_phi_k * factorial_i / B_h)
factorial_i *= (i + 1)
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.cat(b)
# now predictor
use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # shape: (B, K, C, H, W)
if x_t is None:
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], device=b.device)
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
else:
D1s = None
if use_corrector:
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], device=b.device)
else:
rhos_c = torch.linalg.solve(R, b)
model_t = None
# data prediction
if predict_x0:
x_t_ = t / t_prev_0 * x - h_phi_1 * model_prev_0
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - B_h * pred_res
if use_corrector:
model_t = net(x_t, t, class_labels)
model_t = dynamic_thresholding_fn(model_t)
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
x_t = x_t_ - B_h * (corr_res + rhos_c[-1] * D1_t)
else:
x_t_ = x - t * h_phi_1 * model_prev_0
if x_t is None:
if use_predictor:
pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
else:
pred_res = 0
x_t = x_t_ - t * B_h * pred_res
if use_corrector:
denoised = net(x_t, t, class_labels)
model_t = (x_t - denoised) / t
if D1s is not None:
corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
else:
corr_res = 0
D1_t = (model_t - model_prev_0)
x_t = x_t_ - t * B_h * (corr_res + rhos_c[-1] * D1_t)
return x_t, model_t
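#----------------------------------------------------------------------------
# Usage sketch (a hypothetical driver loop, assuming `net` is an EDM-style
# denoiser whose output at (x, t) is an x0 prediction; all names below are
# illustrative, and a real sampler would also handle the final step and
# thresholding details):
#
# t_steps = get_schedule(num_steps, 0.002, 80.0, device=device)
# x = torch.randn(B, C, H, W, device=device) * t_steps[0]
# model_prev_list = [net(x, t_steps[0], class_labels)]
# t_prev_list = [t_steps[0]]
# for i, t in enumerate(t_steps[1:]):
#     order = min(i + 1, max_order)
#     x, model_t = unipc_update(x, model_prev_list, t_prev_list, t, order,
#                               net=net, class_labels=class_labels)
#     t_prev_list.append(t)
#     model_prev_list.append(model_t)
#     # keep only the evaluations the next update can use
#     t_prev_list = t_prev_list[-max_order:]
#     model_prev_list = model_prev_list[-max_order:]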
# A PyTorch reimplementation of DEIS (https://github.com/qsh-zh/deis).
#############################
### Utils for DEIS solver ###
#############################
#----------------------------------------------------------------------------
# Convert the time variable used in EDM (sigma) to the one used in DEIS (t).
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    # VP schedule parameters matched to the given sigma range (as in EDM).
    vp_beta_d = 2 * (torch.log(torch.tensor(sigma_min) ** 2 + 1) / epsilon_s - torch.log(torch.tensor(sigma_max) ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = torch.log(torch.tensor(sigma_max) ** 2 + 1) - 0.5 * vp_beta_d
    t_steps = vp_sigma_inv(vp_beta_d, vp_beta_min)(edm_steps.clone().detach().cpu())
    # Returns the t steps and the VP schedule endpoints beta(0) and beta(1).
    return t_steps, vp_beta_min, vp_beta_d + vp_beta_min
#----------------------------------------------------------------------------
def cal_poly(prev_t, j, taus):
    """
    Evaluate the j-th Lagrange basis polynomial on the nodes `prev_t` at the points `taus`:
        ell_j(tau) = prod_{k != j} (tau - t_k) / (t_j - t_k)
    """
    poly = 1
    for k in range(prev_t.shape[0]):
        if k == j:
            continue
        poly *= (taus - prev_t[k]) / (prev_t[j] - prev_t[k])
    return poly
#----------------------------------------------------------------------------
# Transfer from t to alpha_t.
def t2alpha_fn(beta_0, beta_1, t):
return torch.exp(-0.5 * t ** 2 * (beta_1 - beta_0) - t * beta_0)
#----------------------------------------------------------------------------
def cal_intergrand(beta_0, beta_1, taus):
    """
    Compute the integrand -0.5 * (d log(alpha) / d tau) / sqrt(alpha * (1 - alpha))
    at the points `taus`, using autograd to differentiate log(alpha(tau)).
    """
    with torch.enable_grad():
        taus.requires_grad_(True)
        alpha = t2alpha_fn(beta_0, beta_1, taus)
        log_alpha = alpha.log()
        log_alpha.sum().backward()
        d_log_alpha_dtau = taus.grad
    integrand = -0.5 * d_log_alpha_dtau / torch.sqrt(alpha * (1 - alpha))
    return integrand
#----------------------------------------------------------------------------
def get_deis_coeff_list(t_steps, max_order, N=10000, deis_mode='tab'):
"""
Get the coefficient list for DEIS sampling.
Args:
t_steps: A pytorch tensor. The time steps for sampling.
max_order: A `int`. Maximum order of the solver. 1 <= max_order <= 4
N: A `int`. Use how many points to perform the numerical integration when deis_mode=='tab'.
deis_mode: A `str`. Select between 'tab' and 'rhoab'. Type of DEIS.
Returns:
A pytorch tensor. A batch of generated samples or sampling trajectories if return_inters=True.
"""
if deis_mode == 'tab':
t_steps, beta_0, beta_1 = edm2t(t_steps)
C = []
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
order = min(i+1, max_order)
if order == 1:
C.append([])
else:
                taus = torch.linspace(t_cur, t_next, N)  # discretize the interval for numerical integration
dtau = (t_next - t_cur) / N
prev_t = t_steps[[i - k for k in range(order)]]
coeff_temp = []
integrand = cal_intergrand(beta_0, beta_1, taus)
for j in range(order):
poly = cal_poly(prev_t, j, taus)
coeff_temp.append(torch.sum(integrand * poly) * dtau)
C.append(coeff_temp)
elif deis_mode == 'rhoab':
# Analytical solution, second order
        def get_def_integral_2(a, b, start, end, c):
            coeff = (end**3 - start**3) / 3 - (end**2 - start**2) * (a + b) / 2 + (end - start) * a * b
            return coeff / ((c - a) * (c - b))
# Analytical solution, third order
        def get_def_integral_3(a, b, c, start, end, d):
            coeff = (end**4 - start**4) / 4 - (end**3 - start**3) * (a + b + c) / 3 \
                    + (end**2 - start**2) * (a*b + a*c + b*c) / 2 - (end - start) * a * b * c
            return coeff / ((d - a) * (d - b) * (d - c))
C = []
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
order = min(i, max_order)
if order == 0:
C.append([])
else:
prev_t = t_steps[[i - k for k in range(order+1)]]
if order == 1:
coeff_cur = ((t_next - prev_t[1])**2 - (t_cur - prev_t[1])**2) / (2 * (t_cur - prev_t[1]))
coeff_prev1 = (t_next - t_cur)**2 / (2 * (prev_t[1] - t_cur))
coeff_temp = [coeff_cur, coeff_prev1]
elif order == 2:
                    coeff_cur = get_def_integral_2(prev_t[1], prev_t[2], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_integral_2(t_cur, prev_t[2], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_integral_2(t_cur, prev_t[1], t_cur, t_next, prev_t[2])
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2]
elif order == 3:
                    coeff_cur = get_def_integral_3(prev_t[1], prev_t[2], prev_t[3], t_cur, t_next, t_cur)
                    coeff_prev1 = get_def_integral_3(t_cur, prev_t[2], prev_t[3], t_cur, t_next, prev_t[1])
                    coeff_prev2 = get_def_integral_3(t_cur, prev_t[1], prev_t[3], t_cur, t_next, prev_t[2])
                    coeff_prev3 = get_def_integral_3(t_cur, prev_t[1], prev_t[2], t_cur, t_next, prev_t[3])
coeff_temp = [coeff_cur, coeff_prev1, coeff_prev2, coeff_prev3]
C.append(coeff_temp)
return C
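#----------------------------------------------------------------------------
# Usage sketch (illustrative): pairing the schedule with the DEIS coefficients.
# The sigma range is the EDM default and is only an assumption here.
#
# >>> t_steps = get_schedule(10, 0.002, 80.0, device=torch.device('cpu'))
# >>> coeff_list = get_deis_coeff_list(t_steps, max_order=3, deis_mode='tab')
# >>> len(coeff_list)      # 9: one entry per sampling interval
# >>> len(coeff_list[0])   # 0: only first-order information at the first step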