
Remove unnecessary code in forward function of DDPM_edit.py #141

Open
LIKP0 opened this issue Nov 1, 2024 · 0 comments

LIKP0 commented Nov 1, 2024

Hello all, first of all thanks for your great work!

I need to train the conditioning stage model (the frozen CLIP encoder), so I set self.cond_stage_trainable=True. But then I checked the code of forward in DDPM_edit.py:

def forward(self, x, c, *args, **kwargs):
    """
    input: x: target image feature c: dict(text token, condition image feature)
    output: loss
    """
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    if self.model.conditioning_key is not None:
        assert c is not None
        if self.cond_stage_trainable:  # default false
            c = self.get_learned_conditioning(c)
        if self.shorten_cond_schedule:  # TODO: drop this option
            tc = self.cond_ids[t].to(self.device)
            c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
    return self.p_losses(x, c, t, *args, **kwargs)

I am confused about why c needs to be encoded again via c = self.get_learned_conditioning(c), since c has already been encoded in the get_input function.
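To make the concern concrete, here is a hypothetical trace of the call path, under the assumption stated above that get_input in DDPM_edit.py already encodes the conditioning (the shared_step-style names below follow the CompVis training code and are only a sketch):

    # hypothetical call path, assuming get_input already encodes c in DDPM_edit.py
    x, c = self.get_input(batch, self.first_stage_key)  # c is already the encoded conditioning here
    loss = self(x, c)                                    # dispatches to forward(x, c)
    # inside forward, with self.cond_stage_trainable=True:
    c = self.get_learned_conditioning(c)                 # encodes the already-encoded c a second time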

For comparison, here is the relevant part of the original get_input function in DDPM.py by CompVis:

        if not self.cond_stage_trainable or force_c_encode:
            if isinstance(xc, dict) or isinstance(xc, list):
                # import pudb; pudb.set_trace()
                c = self.get_learned_conditioning(xc)
            else:
                c = self.get_learned_conditioning(xc.to(self.device))
        else:
            c = xc

They handle this well: c is only encoded inside get_input when the conditioning stage is not trainable (or encoding is forced), and the raw conditioning is returned otherwise. Since get_input in DDPM_edit.py already returns an encoded c, I think we can directly remove the if self.model.conditioning_key is not None: branch in the forward function of DDPM_edit.py.
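For reference, here is a minimal sketch of what forward could look like if that branch were dropped, assuming c always arrives already encoded from get_input and that shorten_cond_schedule is not used:

def forward(self, x, c, *args, **kwargs):
    """
    input: x: target image feature c: dict(text token, condition image feature)
    output: loss
    """
    t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
    # c is assumed to be encoded already by get_input, so pass it through unchanged
    return self.p_losses(x, c, t, *args, **kwargs)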

Does anyone have any opinions on this? Thanks in advance.
