Add PLMS sampling and do one 2x size batch per sampling step #51
Merged
Changes from all commits · 4 commits
@@ -0,0 +1,236 @@
"""SAMPLING ONLY.""" | ||
|
||
import torch | ||
import numpy as np | ||
from tqdm import tqdm | ||
from functools import partial | ||
|
||
from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like | ||
|
||
|
||
class PLMSSampler(object): | ||
def __init__(self, model, schedule="linear", **kwargs): | ||
super().__init__() | ||
self.model = model | ||
self.ddpm_num_timesteps = model.num_timesteps | ||
self.schedule = schedule | ||

    def register_buffer(self, name, attr):
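        # Note (added commentary, not in the original diff): buffers are
        # moved to CUDA unconditionally here, so running this sampler on a
        # CPU-only or non-CUDA backend would require adapting this method.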
        if type(attr) == torch.Tensor:
            if attr.device != torch.device("cuda"):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        if ddim_eta != 0:
            raise ValueError('ddim_eta must be 0 for PLMS')
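        # Note (added commentary): PLMS is built on the deterministic
        # (eta == 0) DDIM update; a stochastic step would invalidate the
        # stored eps history that the multistep formulas below rely on.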
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
        to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, e.g. as encoded tokens, ...
               **kwargs
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for PLMS sampling is {size}')

        samples, intermediates = self.plms_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def plms_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        time_range = list(reversed(range(0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running PLMS Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
        old_eps = []

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)
            ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning,
                                      old_eps=old_eps, t_next=ts_next)
            img, pred_x0, e_t = outs
            old_eps.append(e_t)
            if len(old_eps) >= 4:
                old_eps.pop(0)
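            # Note (added commentary): only the three most recent eps
            # predictions are kept; together with the current e_t that is
            # exactly the history the 4th-order multistep formula needs.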
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None):
        b, *_, device = *x.shape, x.device

        def get_model_output(x, t):
            if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
                e_t = self.model.apply_model(x, t, c)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t] * 2)
                c_in = torch.cat([unconditional_conditioning, c])
                e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
                e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)
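                # Note (added commentary): this is the "one 2x size batch
                # per sampling step" change from the PR title. The
                # unconditional and conditional noise predictions are
                # computed with a single model call on a doubled batch
                # instead of two separate calls, then split and recombined
                # with the classifier-free guidance scale.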

            if score_corrector is not None:
                assert self.model.parameterization == "eps"
                e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)

            return e_t

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas

        def get_x_prev_and_pred_x0(e_t, index):
            # select parameters corresponding to the currently considered timestep
            a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
            a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
            sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
            sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

            # current prediction for x_0
            pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
            if quantize_denoised:
                pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
            # direction pointing to x_t
            dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
            noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
            if noise_dropout > 0.:
                noise = torch.nn.functional.dropout(noise, p=noise_dropout)
            x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
            return x_prev, pred_x0
        e_t = get_model_output(x, t)
        if len(old_eps) == 0:
            # Pseudo Improved Euler (2nd order)
            x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
            e_t_next = get_model_output(x_prev, t_next)
            e_t_prime = (e_t + e_t_next) / 2
        elif len(old_eps) == 1:
            # 2nd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (3 * e_t - old_eps[-1]) / 2
        elif len(old_eps) == 2:
            # 3rd order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
        elif len(old_eps) >= 3:
            # 4th order Pseudo Linear Multistep (Adams-Bashforth)
            e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24
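        # Note (added commentary): the weights above are the standard
        # Adams-Bashforth linear multistep coefficients, e.g. (55, -59,
        # 37, -9)/24 at 4th order, applied to the model's eps predictions
        # rather than to an explicit ODE right-hand side -- the "pseudo"
        # linear multistep idea behind PLMS.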

        x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)

        return x_prev, pred_x0, e_t
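For orientation, here is a minimal usage sketch of the new sampler (not part of the diff). It assumes `model` is a loaded LatentDiffusion instance from this repo; the module path `ldm.models.diffusion.plms` and the helpers `get_learned_conditioning` and `decode_first_stage` are assumptions based on the surrounding codebase and should be verified against it.

# Minimal usage sketch (added for illustration, not part of this PR).
import torch
from ldm.models.diffusion.plms import PLMSSampler  # module path assumed

model = load_model_from_config(config, ckpt)  # hypothetical loader; config/ckpt are placeholders
sampler = PLMSSampler(model)

batch_size = 4
with torch.no_grad():
    # text-encoder helper assumed from the repo's sampling scripts
    c = model.get_learned_conditioning(batch_size * ["a painting of a fox"])
    uc = model.get_learned_conditioning(batch_size * [""])
    samples, _ = sampler.sample(S=50,                   # number of PLMS steps
                                batch_size=batch_size,
                                shape=(4, 32, 32),      # latent (C, H, W), model-dependent
                                conditioning=c,
                                unconditional_guidance_scale=5.0,
                                unconditional_conditioning=uc,
                                eta=0.0)                # PLMS requires eta == 0
    x_samples = model.decode_first_stage(samples)       # back to image space

Passing `eta=0.0` is required; `make_schedule` raises a ValueError otherwise.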
Wow, love it!