This repository has been archived by the owner on Feb 14, 2024. It is now read-only.

Is there a wrong understanding in total variation? #27

Open
GuardSkill opened this issue Nov 17, 2018 · 5 comments

Comments


GuardSkill commented Nov 17, 2018

I find this does not conform to the original paper's method. I think the sum of the absolute values should be used for the TV loss, and the TV loss is not the global difference over the whole picture; it is only computed around the hole (P is the region of 1-pixel dilation of the hole region).

def total_variation_loss(image):
    # shift one pixel and get difference (for both x and y direction)
    loss = torch.mean(torch.abs(image[:, :, :, :-1] - image[:, :, :, 1:])) + \
        torch.mean(torch.abs(image[:, :, :-1, :] - image[:, :, 1:, :]))
    return loss
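
For reference, my reading of the L_tv in the paper (a paraphrase, not a quote) sums only over pairs of adjacent pixels that both lie in P:

    L_tv = \sum_{(i,j)\in P,\ (i,j+1)\in P} \frac{\lVert I_{comp}^{i,j+1} - I_{comp}^{i,j} \rVert_1}{N_{I_{comp}}}
         + \sum_{(i,j)\in P,\ (i+1,j)\in P} \frac{\lVert I_{comp}^{i+1,j} - I_{comp}^{i,j} \rVert_1}{N_{I_{comp}}}

whereas the code above averages the absolute differences over every adjacent pixel pair in the image.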

GuardSkill commented Nov 17, 2018

Maybe it should be something like this?

def total_variation_loss(image, mask):
    hole_mask = 1 - mask
    # only penalize differences where the hole mask is set
    loss = torch.sum(torch.abs(hole_mask[:, :, :, :-1] * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(hole_mask[:, :, :-1, :] * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss

GuardSkill changed the title from "In net.py line 85, is there something wrong with respect to the original paper?" to "Sorry, I had a wrong understanding" on Nov 17, 2018
GuardSkill changed the title from "Sorry, I had a wrong understanding" to "There is a wrong understanding in total variation" on Nov 20, 2018
GuardSkill reopened this on Nov 20, 2018

GuardSkill commented Nov 20, 2018

More seriously, it should be the code below rather than the one above (the code above didn't account for the subtraction at the topmost/leftmost pixels of the dilated region):

import torch
import torch.nn as nn

def dilation_holes(hole_mask):
    # dilate the hole mask by one pixel with an all-ones 3x3 convolution
    b, ch, h, w = hole_mask.shape
    dilation_conv = nn.Conv2d(ch, ch, 3, padding=1, bias=False).to(hole_mask.device)
    torch.nn.init.constant_(dilation_conv.weight, 1.0)
    with torch.no_grad():
        output_mask = dilation_conv(hole_mask)
    updated_holes = output_mask != 0
    return updated_holes.float()

def total_variation_loss(image, mask):
    hole_mask = 1 - mask
    dilated_holes = dilation_holes(hole_mask)
    # a pixel pair belongs to P only if both pixels lie in the dilated hole region
    columns_in_Pset = dilated_holes[:, :, :, 1:] * dilated_holes[:, :, :, :-1]
    rows_in_Pset = dilated_holes[:, :, 1:, :] * dilated_holes[:, :, :-1, :]
    loss = torch.sum(torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
        torch.sum(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))
    return loss
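
A minimal usage sketch, assuming the two functions above are defined and that mask is 1 for valid pixels and 0 for holes (the shapes below are made up for illustration):

import torch

# hypothetical batch of 2 RGB images with a square hole
image = torch.rand(2, 3, 256, 256)
mask = torch.ones(2, 3, 256, 256)
mask[:, :, 100:150, 100:150] = 0      # 0 marks the hole region
tv = total_variation_loss(image, mask)
print(tv.item())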

GuardSkill changed the title from "There is a wrong understanding in total variation" to "Is there a wrong understanding in total variation?" on Nov 22, 2018
@Daisy007girl

Have you tried the code that you think it should be? Has it brought any improvement to the results compared to the GitHub author's?


Xavier31 commented Dec 7, 2020

Hi!
@GuardSkill shouldn't it be mean instead of sum in the total_variation_loss function?

loss = torch.mean(torch.abs(columns_in_Pset * (image[:, :, :, 1:] - image[:, :, :, :-1]))) + \
    torch.mean(torch.abs(rows_in_Pset * (image[:, :, 1:, :] - image[:, :, :-1, :])))

@fgiobergia

I would argue that, while not exactly the same loss as the L_tv proposed in the paper, the total_variation_loss() implemented here should behave in just the same way.

Both total_variation_loss() and L_tv are computed on I_comp (output_comp in the code), not on I_out (output). I_comp contains:

  • the ground truth image I_gt (input) in the unmasked part
  • the reconstructed image I_out in the masked part

Since the ground truth image does not change with I_out, all 1-pixel differences outside of the masked region always contribute the same total variation. Inside the masked region (as well as within the 1-pixel dilation of the mask), the TV loss does depend on I_out.

As such, the loss implemented here is L_tv + constant: the two functions thus share the same gradient.

It also seems to me that the current implementation is slightly more efficient, as it requires neither computing the dilated mask nor masking the image.
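
A small numerical check of the "L_tv + constant" claim above (just a sketch: the composition output_comp = mask * gt + (1 - mask) * pred and the max-pool dilation are my assumptions, and sums are used instead of means, which only rescales by a fixed factor):

import torch
import torch.nn.functional as F

def compose(gt, pred, mask):
    # I_comp: ground truth where mask == 1, prediction inside the hole
    return mask * gt + (1 - mask) * pred

def tv_global(img):
    # TV over every adjacent pixel pair, as in the repo (but with sums)
    return torch.sum(torch.abs(img[:, :, :, 1:] - img[:, :, :, :-1])) + \
           torch.sum(torch.abs(img[:, :, 1:, :] - img[:, :, :-1, :]))

def tv_hole(img, mask):
    # TV restricted to pairs whose pixels both lie in the 1-pixel-dilated hole
    dilated = F.max_pool2d(1 - mask, 3, stride=1, padding=1)
    cols = dilated[:, :, :, 1:] * dilated[:, :, :, :-1]
    rows = dilated[:, :, 1:, :] * dilated[:, :, :-1, :]
    return torch.sum(torch.abs(cols * (img[:, :, :, 1:] - img[:, :, :, :-1]))) + \
           torch.sum(torch.abs(rows * (img[:, :, 1:, :] - img[:, :, :-1, :])))

torch.manual_seed(0)
gt = torch.rand(1, 3, 16, 16)
mask = torch.ones(1, 3, 16, 16)
mask[:, :, 5:10, 5:10] = 0                     # square hole
pred_a = torch.rand(1, 3, 16, 16)
pred_b = torch.rand(1, 3, 16, 16)

# The gap between the global TV and the hole-restricted TV does not depend on
# the prediction, i.e. the global loss is the restricted loss plus a constant.
diff_a = tv_global(compose(gt, pred_a, mask)) - tv_hole(compose(gt, pred_a, mask), mask)
diff_b = tv_global(compose(gt, pred_b, mask)) - tv_hole(compose(gt, pred_b, mask), mask)
print(torch.allclose(diff_a, diff_b))          # True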
