Question about DDPM and DDIM sampling. #13

Open
e4s2022 opened this issue Dec 13, 2022 · 1 comment

e4s2022 commented Dec 13, 2022

Hi, thanks for sharing your excellent work!

I just walked through the code base and noticed that during sampling you iterate the timestep t from 0 to 999 (see here). I think the reverse pass should start from 999 and run down to 0, so I'm a little confused about this.

Another question: what does the denoise option do in the last sampling step? Please check here.

Both questions apply to either the DDPM or the DDIM sampler. I would really appreciate an explanation.


JunyaoHu commented Mar 21, 2023

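Hello, here is how I see it.

The author uses an NCSN-style diffusion model based on Langevin dynamics. In that paper's parameter setup, the start value is larger than the end value, so betas are stored in the list from large to small. During denoising, betas should go from large to small, so the list index should run from 0 to 999. Under this convention X_T is the real image.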


sigma_begin: 0.02
sigma_end: 0.0001

    elif config.model.sigma_dist == 'linear':
        return torch.linspace(config.model.sigma_begin, config.model.sigma_end,
                              T).to(config.device)

    if self.schedule == 'linear':
        self.register_buffer('betas', get_sigmas(config))
        self.register_buffer('alphas', torch.cumprod(1 - self.betas.flip(0), 0).flip(0))
        self.register_buffer('alphas_prev', torch.cat([self.alphas[1:], torch.tensor([1.0]).to(self.alphas)]))

for i, step in enumerate(steps):
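
To make the index convention above concrete, here is a minimal sketch in plain Python (no torch). It assumes the `alphas` buffer plays the role of DDPM's cumulative product, and that `T = 1000`; the helper `linspace` and the variable names are mine, not from either repository. It rebuilds the descending schedule, applies the same flip / cumprod / flip as the snippet above, and compares the result with an ascending DDPM-style schedule read backwards.

```python
import math
from itertools import accumulate
from operator import mul


def linspace(start, end, n):
    """Evenly spaced values from start to end inclusive (hypothetical helper)."""
    step = (end - start) / (n - 1)
    return [start + i * step for i in range(n)]


T = 1000  # assumed number of timesteps

# Descending schedule, as in the config above: sigma_begin=0.02 > sigma_end=0.0001.
betas_desc = linspace(0.02, 0.0001, T)
# Plain-Python equivalent of torch.cumprod(1 - betas.flip(0), 0).flip(0).
alphas_desc = list(accumulate((1 - b for b in reversed(betas_desc)), mul))[::-1]

# Ascending DDPM-style schedule: beta_start=0.0001 < beta_end=0.02.
betas_asc = linspace(0.0001, 0.02, T)
alphas_bar_asc = list(accumulate((1 - b for b in betas_asc), mul))

# Index i in the descending convention corresponds to index T-1-i in DDPM.
mirrored = all(
    math.isclose(a, b, rel_tol=1e-9)
    for a, b in zip(alphas_desc, reversed(alphas_bar_asc))
)
print("mirrored schedules:", mirrored)
print("index 0 (noisiest here):", alphas_desc[0])
print("index T-1 (cleanest here):", alphas_desc[-1])
```

If the assumption holds, index 0 in this repo carries the most noise, so looping from 0 to 999 here visits the same noise levels as looping from 999 down to 0 in DDPM.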

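For comparison:
In the original DDPM paper, the start value is smaller than the end value, so betas are stored in the list from small to large. The denoising timesteps run from list index 999 down to 0. Under this convention X_0 is the real image.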


https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/scripts/run_celebahq.py#L132-L137

def train(
    exp_name, tpu_name, bucket_name_prefix, model_name='unet2d16b2c112244', dataset='celebahq256',
    optimizer='adam', total_bs=64, grad_clip=1., lr=0.00002, warmup=5000,
    num_diffusion_timesteps=1000, beta_start=0.0001, beta_end=0.02, beta_schedule='linear', loss_type='noisepred',
    dropout=0.0, randflip=1, block_size=1,
    tfds_data_dir='tensorflow_datasets', log_dir='logs'

https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L26-L27

  elif beta_schedule == 'linear':
    betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)

https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L205-L217

    i_0 = tf.constant(self.num_timesteps - 1, dtype=tf.int32)
    img_0 = noise_fn(shape=shape, dtype=tf.float32)
    _, img_final = tf.while_loop(
      cond=lambda i_, _: tf.greater_equal(i_, 0),
      body=lambda i_, img_: [
        i_ - 1,
        self.p_sample(
          denoise_fn=denoise_fn, x=img_, t=tf.fill([shape[0]], i_), noise_fn=noise_fn, return_pred_xstart=False)
      ],
      loop_vars=[i_0, img_0],
      shape_invariants=[i_0.shape, img_0.shape],
      back_prop=False
    )

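LDM (latent diffusion) is similar.
In LDM's DDPM, the start value is smaller than the end value, so betas are stored in the list from small to large. The denoising timesteps run from list index 999 down to 0. X_0 is the real image.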

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/configs/latent-diffusion/cin256-v2.yaml#L5-L6

    linear_start: 0.0015
    linear_end: 0.0195

https://github.com/CompVis/latent-diffusion/blob/171cf29fb54afe048b03ec73da8abb9d102d0614/ldm/modules/diffusionmodules/util.py#L22-L25

    if schedule == "linear":
        betas = (
                torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
        )
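
As a quick check of the ordering, the following plain-Python sketch mirrors the "linear" branch above using the two config values quoted earlier. `n_timestep = 1000` is an assumption on my part, and the function name `ldm_linear_betas` is hypothetical.

```python
import math


def ldm_linear_betas(linear_start, linear_end, n_timestep):
    """Betas that are linear in sqrt(beta), then squared (sketch of the branch above)."""
    lo, hi = math.sqrt(linear_start), math.sqrt(linear_end)
    step = (hi - lo) / (n_timestep - 1)
    return [(lo + i * step) ** 2 for i in range(n_timestep)]


# Values from the cin256-v2.yaml excerpt; 1000 timesteps is assumed.
betas = ldm_linear_betas(0.0015, 0.0195, 1000)

print("first beta:", betas[0])
print("last beta:", betas[-1])
print("ascending:", all(a < b for a, b in zip(betas, betas[1:])))
```

The list starts near `linear_start` and ends near `linear_end`, so it is ascending like the DDPM schedule, which is why the sampler iterates the index in reverse.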

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddpm.py#L258-L260

        for i in tqdm(reversed(range(0, self.num_timesteps)), desc='Sampling t', total=self.num_timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long),
                                clip_denoised=self.clip_denoised)

LDM DDIM

https://github.com/CompVis/latent-diffusion/blob/a506df5756472e2ebaf9078affdde2c4f1502cd4/ldm/models/diffusion/ddim.py#L133-L160

        time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")


        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)


        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)


            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img


            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback: callback(i)
            if img_callback: img_callback(pred_x0, i)


            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)
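
The relationship between `index` and `step` in the loop above can be sketched on its own. This is only an illustration: the uniformly spaced `timesteps` list is a made-up example, not the exact sub-sequence LDM builds, and `ddim_iteration_order` is a hypothetical helper.

```python
def ddim_iteration_order(timesteps):
    """Return (index, step) pairs in the order the loop above would visit them."""
    total_steps = len(timesteps)
    time_range = timesteps[::-1]  # same effect as np.flip(timesteps)
    return [(total_steps - i - 1, step) for i, step in enumerate(time_range)]


# Hypothetical sub-sequence of 10 DDIM steps out of 1000 DDPM steps.
timesteps = list(range(0, 1000, 100))
order = ddim_iteration_order(timesteps)

for index, step in order:
    # index always points at the position of step in the ascending list
    assert timesteps[index] == step

print(order[0], order[-1])
```

So `step` is the (large to small) model timestep fed to the network, while `index` counts down through the ascending DDIM schedule arrays in lockstep.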
