[BUG] BUG in cross_entropy_loss.py #1525

Closed
Dawn-bin opened this issue Apr 28, 2022 · 2 comments

@Dawn-bin
Contributor

Describe the bug
Hi, I copied this code into my own project and found an issue when using cross_entropy_loss.
At line 121:

if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    assert label.max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    pred = pred.squeeze()

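Shouldn't the check 'label.max() <= 1' mask out ignore_index? ignore_index is often set to 255, so any label map containing ignored pixels fails this assertion even when its real classes are only 0 and 1.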
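A minimal sketch of the failure, assuming PyTorch and ignore_index = 255; the toy label values are made up for illustration:

```python
import torch

ignore_index = 255  # the value the issue says is commonly used

# A tiny binary label map (N=1, H=2, W=3) where two pixels are ignored.
label = torch.tensor([[[0, 1, 255],
                       [1, 0, 255]]])

# The original check sees the ignored pixels (255) and fails.
original_ok = bool(label.max() <= 1)

# Masking out ignore_index first leaves only the real class ids {0, 1}.
masked_ok = bool(label[label != ignore_index].max() <= 1)

print(original_ok, masked_ok)  # False True
```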

Bug fix

if pred.size(1) == 1:
    # For binary class segmentation, the shape of pred is
    # [N, 1, H, W] and that of label is [N, H, W].
    assert label[label != ignore_index].max() <= 1, \
        'For pred with shape [N, 1, H, W], its label must have at ' \
        'most 2 classes'
    pred = pred.squeeze()
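One edge case the masked check does not cover: if every pixel in the batch equals ignore_index, label[label != ignore_index] is empty, and calling .max() on an empty tensor raises a RuntimeError. Below is a minimal sketch of a guarded version, written as a hypothetical standalone helper (check_binary_label is not part of cross_entropy_loss.py):

```python
import torch


def check_binary_label(label: torch.Tensor, ignore_index: int = 255) -> None:
    """Hypothetical helper: assert that non-ignored labels are only 0 or 1."""
    valid = label[label != ignore_index]
    # If every pixel is ignored, `valid` is empty and `valid.max()` would
    # raise a RuntimeError, so skip the check in that case.
    if valid.numel() == 0:
        return
    assert valid.max() <= 1, (
        'For pred with shape [N, 1, H, W], its label must have at '
        'most 2 classes')


# Usage: a fully ignored label map no longer crashes the check,
# and a normal binary map with some ignored pixels passes.
check_binary_label(torch.full((1, 2, 2), 255))
check_binary_label(torch.tensor([[[0, 1], [255, 1]]]))
```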
@MengzhangLI MengzhangLI self-assigned this Apr 28, 2022
@MengzhangLI
Contributor

Hi, thanks for your suggestion. We think you are correct; could you open a PR to fix this problem? Thanks in advance.

aravind-h-v pushed a commit to aravind-h-v/mmsegmentation that referenced this issue Mar 27, 2023
@marz869

marz869 commented Nov 27, 2024

Hi there, I ran into another problem. When the labels are float32, label.max() <= 1 can return False.
For example, a value that prints as tensor(1.0000) can still compare as greater than 1. The printed value is rounded to four decimal places, so the stored float32 value may actually be slightly above 1 (for example 1.00000012) rather than exactly 1.
To work around this, I changed this line:
assert label.max() <= 1, "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
to this line:
assert torch.allclose(label.max(), torch.tensor(1, dtype=torch.float, device=label.device)), "For pred with shape [N, 1, H, W], its label must have at " "most 2 classes"
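As a hedged alternative sketch: an allclose-to-1 check requires the maximum to be close to 1, so it would also reject a float label map containing only 0s (a valid binary label), and it does not mask out ignore_index. Comparing against 1 plus a small tolerance keeps both cases working. The helper name binary_label_ok and the tolerance atol=1e-6 are assumptions for illustration, not values from the codebase:

```python
import torch


def binary_label_ok(label: torch.Tensor, ignore_index: int = 255,
                    atol: float = 1e-6) -> bool:
    """Hypothetical check: non-ignored labels must not exceed 1 + atol."""
    valid = label[label != ignore_index]
    if valid.numel() == 0:
        return True  # every pixel is ignored, so there is nothing to check
    # The tolerance absorbs float32 values such as 1.00000012, which print
    # as 1.0000 but compare as greater than 1. An all-zero map still passes.
    return bool(valid.max() <= 1 + atol)


x = torch.tensor(1.0000001, dtype=torch.float32)  # stored as 1.00000012
print(bool(x <= 1))                                        # False
print(binary_label_ok(torch.tensor([[[0.0, x.item()]]])))  # True
print(binary_label_ok(torch.zeros(1, 2, 2)))               # True
print(binary_label_ok(torch.tensor([[[0.0, 2.0]]])))       # False
```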
