I tried replacing BCE loss with DICE loss and my model wouldn't converge. When I looked closer, I noticed that while the input and target are flattened, the mask isn't. So if you pass a mask with the same shape as the target, the multiplication `flat_input * mask` broadcasts `flat_input` back to an un-flattened shape:
```python
def _binary_class(self, input, target, mask=None):
    flat_input = input.view(-1)
    flat_target = target.view(-1).float()
    flat_input = torch.sigmoid(flat_input) if self.with_logits else flat_input

    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask
        flat_target = flat_target * mask
    else:
        mask = torch.ones_like(target)
```
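For example, here is a minimal sketch (with hypothetical shapes for a binary-classification output) showing how the broadcast changes the shape:

```python
import torch

# Hypothetical shapes: a (4, 1) target/mask and the flattened input.
flat_input = torch.rand(4)        # input.view(-1) -> shape (4,)
mask = torch.ones(4, 1)           # same shape as the target, not flattened

# Broadcasting (4,) against (4, 1) yields a (4, 4) tensor, so the "masked"
# input no longer lines up element-wise with flat_target.
print((flat_input * mask).shape)  # torch.Size([4, 4])
```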
I made the following change and my model started converging immediately:
```python
    if mask is not None:
        mask = mask.float()
        flat_input = flat_input * mask.view(-1)
        flat_target = flat_target * mask.view(-1)
    else:
        mask = torch.ones_like(target)
```
Although I think a better fix is to actually apply the mask (i.e. keep only the unmasked positions) rather than zero out the masked inputs/targets, e.g. the sketch below.
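Something along these lines, as a sketch only, assuming `mask` is a 0/1 tensor with the same shape as `target` (the rest of the function may need adjusting, since `mask` is used again later):

```python
    if mask is not None:
        keep = mask.view(-1).bool()    # positions to keep
        flat_input = flat_input[keep]  # drop masked positions entirely
        flat_target = flat_target[keep]
```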