Merge pull request #16035 from v0xie/cfgpp

Add new sampler DDIM CFG++

AUTOMATIC1111 authored Jul 6, 2024
2 parents ace00a1 + 663a4d8 commit eb112c6

Showing 3 changed files with 48 additions and 0 deletions.

10 changes: 10 additions & 0 deletions modules/sd_samplers_cfg_denoiser.py

@@ -58,6 +58,8 @@ def __init__(self, sampler):
         self.model_wrap = None
         self.p = None
 
+        self.last_noise_uncond = None
+
         # NOTE: masking before denoising can cause the original latents to be oversmoothed
         # as the original latents do not have noise
         self.mask_before_denoising = False
@@ -160,6 +162,8 @@ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
         # so is_edit_model is set to False to support AND composition.
         is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
 
+        is_cfg_pp = 'CFG++' in self.sampler.config.name
+
         conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
         uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
 
@@ -273,10 +277,16 @@ def apply_blend(current_latent):
         denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
         cfg_denoised_callback(denoised_params)
 
+        if is_cfg_pp:
+            self.last_noise_uncond = x_out[-uncond.shape[0]:]
+            self.last_noise_uncond = torch.clone(self.last_noise_uncond)
+
         if is_edit_model:
             denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
         elif skip_uncond:
             denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
+        elif is_cfg_pp:
+            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale / 12.5)  # a CFG scale in (1.0, 12.5) maps to a CFG++ guidance weight in (0, 1)
         else:
             denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
 
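Note: a minimal sketch (not webui code) of what the CFG++ branch above computes in the single-prompt case, with the scale remap folded in; the function name and argument names are illustrative:

import torch

def combine_denoised_sketch(denoised_cond: torch.Tensor, denoised_uncond: torch.Tensor,
                            cond_scale: float, is_cfg_pp: bool) -> torch.Tensor:
    # CFG++ path: a CFG slider value in (1.0, 12.5) becomes a (0, 1) guidance weight
    scale = cond_scale / 12.5 if is_cfg_pp else cond_scale
    # classifier-free guidance: move from the unconditional prediction
    # toward the conditional one by `scale`
    return denoised_uncond + (denoised_cond - denoised_uncond) * scale
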
1 change: 1 addition & 0 deletions modules/sd_samplers_timesteps.py

@@ -10,6 +10,7 @@

 samplers_timesteps = [
     ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}),
+    ('DDIM CFG++', sd_samplers_timesteps_impl.ddim_cfgpp, ['ddim_cfgpp'], {}),
     ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}),
     ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}),
 ]
37 changes: 37 additions & 0 deletions modules/sd_samplers_timesteps_impl.py

@@ -40,6 +40,43 @@ def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=
     return x
 
 
+@torch.no_grad()
+def ddim_cfgpp(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0):
+    """Implements CFG++: Manifold-constrained Classifier-Free Guidance for Diffusion Models (2024).
+    Uses the unconditional noise prediction instead of the conditional noise to guide the denoising direction.
+    The CFG scale is divided by 12.5 to map CFG from [0.0, 12.5] to [0, 1.0].
+    """
+    alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
+    alphas = alphas_cumprod[timesteps]
+    alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(float64(x))
+    sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
+    sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy()))
+
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones((x.shape[0]))
+    s_x = x.new_ones((x.shape[0], 1, 1, 1))
+    for i in tqdm.trange(len(timesteps) - 1, disable=disable):
+        index = len(timesteps) - 1 - i
+
+        e_t = model(x, timesteps[index].item() * s_in, **extra_args)
+        last_noise_uncond = model.last_noise_uncond
+
+        a_t = alphas[index].item() * s_x
+        a_prev = alphas_prev[index].item() * s_x
+        sigma_t = sigmas[index].item() * s_x
+        sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
+
+        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
+        dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * last_noise_uncond
+        noise = sigma_t * k_diffusion.sampling.torch.randn_like(x)
+        x = a_prev.sqrt() * pred_x0 + dir_xt + noise
+
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0})
+
+    return x
+
+
 @torch.no_grad()
 def plms(model, x, timesteps, extra_args=None, callback=None, disable=None):
     alphas_cumprod = model.inner_model.inner_model.alphas_cumprod
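
To make the change concrete: a minimal, self-contained sketch (illustrative names, not webui code) contrasting one deterministic (eta = 0) DDIM step with its CFG++ counterpart. Both share the same pred_x0; CFG++ renoises along the unconditional noise prediction instead of the guided one.

import torch

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)         # noisy latent x_t
e_t_cond = torch.randn_like(x)      # conditional noise prediction
e_t_uncond = torch.randn_like(x)    # unconditional noise prediction
a_t, a_prev = 0.8, 0.9              # alphas_cumprod at t and at the previous timestep

lam = 6.0 / 12.5                    # a CFG scale of 6.0 remapped into (0, 1)
e_t = e_t_uncond + lam * (e_t_cond - e_t_uncond)   # guided noise

pred_x0 = (x - (1 - a_t) ** 0.5 * e_t) / a_t ** 0.5

# standard DDIM renoises with the guided prediction ...
x_prev_ddim = a_prev ** 0.5 * pred_x0 + (1 - a_prev) ** 0.5 * e_t
# ... while CFG++ renoises with the unconditional prediction
x_prev_cfgpp = a_prev ** 0.5 * pred_x0 + (1 - a_prev) ** 0.5 * e_t_uncond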
