Describe the bug
Hi, I copied this code to use in my own project, and I found an issue when using cross_entropy_loss.
In line 121:
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        assert label.max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze()
Should 'label.max() <= 1' mask out ignore_index? The ignore_index is often set to 255, so any label containing ignored pixels fails this assertion.
Bug fix
    if pred.size(1) == 1:
        # For binary class segmentation, the shape of pred is
        # [N, 1, H, W] and that of label is [N, H, W].
        assert label[label != ignore_index].max() <= 1, \
            'For pred with shape [N, 1, H, W], its label must have at ' \
            'most 2 classes'
        pred = pred.squeeze()
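One edge case the masked assertion may still hit: if every pixel in a label equals ignore_index, the masked tensor is empty and .max() raises an error. A minimal sketch of a guarded check follows. The helper name check_binary_label is hypothetical (it is not part of the library), and the default ignore_index of 255 is assumed from the discussion above.

```python
import torch


def check_binary_label(label: torch.Tensor, ignore_index: int = 255) -> None:
    """Hypothetical helper: assert that all non-ignored labels are 0 or 1."""
    # Keep only the pixels that take part in the loss.
    valid = label[label != ignore_index]
    # .max() raises on an empty tensor, so skip the check
    # when every pixel is ignored.
    if valid.numel() > 0:
        assert valid.max() <= 1, (
            'For pred with shape [N, 1, H, W], its label must have at '
            'most 2 classes')


# Label with shape [N, H, W] = [1, 2, 3]; 255 marks ignored pixels.
label = torch.tensor([[[0, 1, 255], [1, 0, 255]]])
check_binary_label(label)  # passes because 255 is masked out

# A label in which every pixel is ignored also passes.
all_ignored = torch.full((1, 2, 3), 255, dtype=torch.long)
check_binary_label(all_ignored)
```

A label that contains a class index of 2 or more, other than ignore_index, still fails the assertion as before.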
Hi there, I ran into another problem. When the labels are float32, label.max() <= 1 can return False.
For example:
a maximum that prints as tensor(1.0000) can still fail the <= 1 check. That might be because the stored float32 value is slightly larger than 1, even though it prints as 1.0000.
To work around the problem, I changed this line:
assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
to this line:
assert torch.allclose(label.max(), torch.tensor(1, dtype=torch.float, device=label.device)), "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
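Note that torch.allclose as used above asserts that the maximum is approximately equal to 1, so a label that contains only zeros would fail the check. A possible alternative, sketched below, keeps the original "at most 1" meaning and only adds a small tolerance. The helper name max_at_most_one and the tolerance value 1e-6 are assumptions for illustration, not part of the library.

```python
import torch


def max_at_most_one(label: torch.Tensor, atol: float = 1e-6) -> bool:
    """Hypothetical helper: True if label.max() <= 1 within a small tolerance."""
    return bool(label.max() <= 1 + atol)


# A float32 value just above 1 that prints with four decimals as 1.0000.
noisy = torch.tensor([0.0, 1.0000001])

assert not bool(noisy.max() <= 1)       # the strict check fails
assert max_at_most_one(noisy)           # the tolerant check passes
assert max_at_most_one(torch.zeros(4))  # an all-background label also passes
assert not max_at_most_one(torch.tensor([0.0, 2.0]))  # a third class is rejected
```

The tolerance should stay well below the gap between class indices, so that a real third class is still caught.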