
[Fix] Fix the bug that when all pixels in an image is ignored, the ac… #1336

Merged (5 commits, Mar 9, 2022)
10 changes: 7 additions & 3 deletions mmseg/models/losses/accuracy.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn


@@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
         correct = correct & (pred_value > thresh).t()
     correct = correct[:, target != ignore_index]
     res = []
+    eps = torch.finfo(torch.float32).eps
     for k in topk:
-        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(
-            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
+        # Avoid causing ZeroDivisionError when all pixels
+        # of an image are ignored
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
+        total_num = target[target != ignore_index].numel() + eps
+        res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res


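A minimal standalone sketch (not part of the PR) that reproduces the patched formula, to show why adding eps to both the numerator and the denominator keeps the metric finite when every pixel carries the ignored label. The tensor values and variable names here are made up for illustration only.

import torch

# Hypothetical example: 5 pixels, 4 classes, and a target where every pixel
# equals the ignored label, so no pixel contributes to the metric.
pred = torch.rand(5, 4)
target = torch.full((5,), 2, dtype=torch.long)
ignore_index = 2

pred_label = pred.argmax(dim=1)
valid = target != ignore_index                 # all False here
correct = (pred_label == target) & valid

eps = torch.finfo(torch.float32).eps
# Without eps, total_num would be 0 and the scalar division below would raise
# ZeroDivisionError; with eps, eps / eps yields the 100% convention.
correct_k = correct.float().sum(0, keepdim=True) + eps
total_num = target[valid].numel() + eps
acc = correct_k * (100.0 / total_num)
print(acc)  # tensor([100.]) instead of an exception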
22 changes: 14 additions & 8 deletions tests/test_models/test_losses/test_utils.py
@@ -56,50 +56,56 @@ def test_accuracy():
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=None)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for ignore_index with a wrong prediction of that index
     true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for ignore_index 1 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 75
+    assert torch.allclose(acc, torch.tensor(75.0))
 
     # test for ignore_index 4 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=4)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 80
+    assert torch.allclose(acc, torch.tensor(80.0))
 
+    # test for ignoring all the pixels
+    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=2)
+    acc = accuracy(pred, true_label)
+    assert torch.allclose(acc, torch.tensor(100.0))
+
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for top1 with score thresh=0.8
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, thresh=0.8)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 40
+    assert torch.allclose(acc, torch.tensor(40.0))
 
     # test for top2
     accuracy = Accuracy(topk=2)
     label = torch.Tensor([3, 2, 0, 0, 2]).long()
     acc = accuracy(pred, label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for both top1 and top2
     accuracy = Accuracy(topk=(1, 2))
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     acc = accuracy(pred, true_label)
     for a in acc:
-        assert a.item() == 100
+        assert torch.allclose(a, torch.tensor(100.0))
 
     # topk is larger than pred class number
     with pytest.raises(AssertionError):
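For reference, a short usage sketch of how a caller sees the fix. The prediction scores below are hypothetical and the import path is assumed from the test module, not quoted from this PR.

import torch
from mmseg.models.losses import Accuracy  # import path assumed

# Hypothetical scores for 5 pixels over 4 classes.
pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5],
                     [0.1, 0.1, 0.2, 0.6],
                     [0.9, 0.0, 0.8, 0.4],
                     [0.0, 0.9, 0.7, 0.1],
                     [0.0, 0.0, 0.8, 0.1]])
# Every label equals ignore_index, so no pixel is evaluated.
true_label = torch.Tensor([2, 2, 2, 2, 2]).long()

accuracy = Accuracy(topk=1, ignore_index=2)
acc = accuracy(pred, true_label)  # raised ZeroDivisionError before this fix
assert torch.allclose(acc, torch.tensor(100.0))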