Hello all, and thanks for your great work!
I need to train the conditioning stage model (the CLIP encoder, which is frozen by default), so I set `self.cond_stage_trainable=true`. But when I look at `forward` in DDPM_edit.py, I see:

    def forward(self, x, c, *args, **kwargs):
        """
        input: x: target image feature  c: dict(text token, condition image feature)
        output: loss
        """
        t = torch.randint(0, self.num_timesteps, (x.shape[0],), device=self.device).long()
        if self.model.conditioning_key is not None:
            assert c is not None
            if self.cond_stage_trainable:  # default false
                c = self.get_learned_conditioning(c)
            if self.shorten_cond_schedule:  # TODO: drop this option
                tc = self.cond_ids[t].to(self.device)
                c = self.q_sample(x_start=c, t=tc, noise=torch.randn_like(c.float()))
        return self.p_losses(x, c, t, *args, **kwargs)

I am confused about why `c` needs to be encoded again with `c = self.get_learned_conditioning(c)`, since `c` has already been encoded in the `get_input` function.
For comparison, here is the relevant part of the original `get_input` function in DDPM.py by CompVis:

    if not self.cond_stage_trainable or force_c_encode:
        if isinstance(xc, dict) or isinstance(xc, list):
            # import pudb; pudb.set_trace()
            c = self.get_learned_conditioning(xc)
        else:
            c = self.get_learned_conditioning(xc.to(self.device))
    else:
        c = xc

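They handle this well: `get_input` encodes `c` only when the conditioning stage is not trainable (or `force_c_encode` is set), so with a trainable conditioning stage the only encoding happens in `forward`. I think we can directly remove the `if self.model.conditioning_key is not None:` branch from the `forward` function of DDPM_edit.py.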
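To make the concern concrete, here is a minimal sketch that only counts how often the conditioning gets encoded. It is not the repository's code: `ToyLatentDiffusion`, `count_encodes`, and the `get_input_always_encodes` flag are hypothetical names, and the assumption that `get_input` in DDPM_edit.py encodes unconditionally comes from the description above, not from a check of that file.

```python
class ToyLatentDiffusion:
    """Hypothetical stand-in that only tracks calls to the conditioning encoder."""

    def __init__(self, cond_stage_trainable, get_input_always_encodes):
        self.cond_stage_trainable = cond_stage_trainable
        # Assumption: True mimics the behaviour described for DDPM_edit.py,
        # False mimics the guarded CompVis get_input quoted above.
        self.get_input_always_encodes = get_input_always_encodes
        self.encode_calls = 0

    def get_learned_conditioning(self, c):
        # Placeholder for the real encoder; just counts invocations.
        self.encode_calls += 1
        return ("encoded", c)

    def get_input(self, xc, force_c_encode=False):
        if (self.get_input_always_encodes
                or not self.cond_stage_trainable
                or force_c_encode):
            return self.get_learned_conditioning(xc)
        return xc  # left raw so that forward can encode it with gradients

    def forward(self, c):
        # Same guard as in the forward function quoted above.
        if self.cond_stage_trainable:
            c = self.get_learned_conditioning(c)
        return c


def count_encodes(cond_stage_trainable, get_input_always_encodes):
    """Run get_input followed by forward and report the number of encoder calls."""
    model = ToyLatentDiffusion(cond_stage_trainable, get_input_always_encodes)
    model.forward(model.get_input("raw prompt"))
    return model.encode_calls


print(count_encodes(True, False))  # guarded get_input, trainable encoder
print(count_encodes(True, True))   # unguarded get_input, trainable encoder
```

Under these assumptions the guarded variant encodes once and the unguarded one encodes twice when `cond_stage_trainable` is set, which is the double encoding in question. In this toy model, restoring the guard in `get_input` would avoid it just as well as dropping the encode step in `forward`.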
Does anyone have an opinion on this? Thanks in advance.