Masking #18

Open
david-waterworth opened this issue Oct 25, 2021 · 0 comments
david-waterworth commented Oct 25, 2021

I tried replacing BCE loss with Dice loss and my model wouldn't converge. Looking more closely, I noticed that while the input and target are flattened, the mask isn't. So if you pass a mask with the same shape as the (unflattened) target, the multiplication flat_input * mask broadcasts and effectively un-flattens flat_input:

def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input

    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)
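
To make the shape issue concrete, here's a toy check. The (N, 1) mask/target shape below is just an assumption for illustration; with other non-flat shapes the multiplication either broadcasts the same way or raises a size-mismatch error.

import torch

N = 4
flat_input = torch.rand(N)        # input already flattened to (N,)
mask = torch.ones(N, 1)           # mask left in its unflattened (N, 1) shape

print((flat_input * mask).shape)  # torch.Size([4, 4]) -- broadcast, not flat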

I made the following change, and my model started converging immediately:

    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask.view(-1)
        flat_target = flat_target * mask.view(-1)
    else:
        mask = torch.ones_like(target)
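
For comparison, the same kind of toy check with the mask flattened first (shapes again purely illustrative):

import torch

N = 4
flat_input = torch.rand(N)
target = torch.randint(0, 2, (N, 1))
mask = torch.ones(N, 1)

flat_target = target.view(-1).float()
flat_mask = mask.float().view(-1)

print((flat_input * flat_mask).shape)   # torch.Size([4]) -- stays flat
print((flat_target * flat_mask).shape)  # torch.Size([4])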

Although I think a better fix is to actually index with the mask, dropping the masked positions entirely, rather than just zeroing out the masked inputs/targets, i.e.:

        mask = mask.view(-1).bool()
        flat_input = flat_input[mask]
        flat_target = flat_target[mask]
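
For completeness, here's a minimal self-contained sketch of what that indexing-based version could look like as a standalone function. The name binary_dice_loss, the smooth term, and the exact Dice formula are generic placeholders of mine, not necessarily what this repo implements.

import torch

def binary_dice_loss(input, target, mask=None, with_logits=True, smooth=1.0):
    # Flatten everything up front so all shapes line up.
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    if with_logits:
        flat_input = torch.sigmoid(flat_input)

    if mask is not None:
        # Boolean indexing drops masked positions entirely
        # instead of multiplying them by zero.
        flat_mask = mask.view(-1).bool()
        flat_input = flat_input[flat_mask]
        flat_target = flat_target[flat_mask]

    # Generic soft-Dice formulation.
    intersection = (flat_input * flat_target).sum()
    return 1 - (2 * intersection + smooth) / (flat_input.sum() + flat_target.sum() + smooth)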