From e8cfb9832b05a1f7744f2e66efe36d571c388a8a Mon Sep 17 00:00:00 2001
From: Rockey <41846794+RockeyCoss@users.noreply.github.com>
Date: Wed, 9 Mar 2022 13:20:46 +0800
Subject: [PATCH] [Fix] Fix the bug that when all pixels in an image are
 ignored, the ac… (#1336)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* [Fix] Fix the bug that when all pixels in an image are ignored, the
  accuracy calculation raises ZeroDivisionError

* use eps

* all close

* add ignore test

* add eps
---
 mmseg/models/losses/accuracy.py             | 10 +++++++---
 tests/test_models/test_losses/test_utils.py | 22 +++++++++++++--------
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py
index 7cd15e222f1..28d55c4e459 100644
--- a/mmseg/models/losses/accuracy.py
+++ b/mmseg/models/losses/accuracy.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import torch
 import torch.nn as nn
 
 
@@ -46,10 +47,13 @@ def accuracy(pred, target, topk=1, thresh=None, ignore_index=None):
         correct = correct & (pred_value > thresh).t()
     correct = correct[:, target != ignore_index]
     res = []
+    eps = torch.finfo(torch.float32).eps
     for k in topk:
-        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
-        res.append(
-            correct_k.mul_(100.0 / target[target != ignore_index].numel()))
+        # Avoid causing ZeroDivisionError when all pixels
+        # of an image are ignored
+        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + eps
+        total_num = target[target != ignore_index].numel() + eps
+        res.append(correct_k.mul_(100.0 / total_num))
     return res[0] if return_single else res
 
 
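The fix above hinges on adding a tiny epsilon to both the numerator and the
denominator: when every pixel carries the ignored label, total_num would
otherwise be 0 and the division raises ZeroDivisionError, while with the
guard the expression reduces to roughly eps / eps * 100 ≈ 100. Below is a
minimal sketch of that arithmetic, separate from the patch; the tensor
values and the 1-D shapes are illustrative only (the real `correct` tensor
in accuracy() is 2-D):

    import torch

    pred_label = torch.tensor([2, 2, 2, 2, 2])  # hypothetical top-1 predictions
    target = torch.tensor([2, 2, 2, 2, 2])      # every pixel has the ignored label
    ignore_index = 2

    # Keep only the non-ignored positions: empty when everything is ignored.
    correct = pred_label.eq(target)[target != ignore_index]

    # Unpatched formula: numel() is 0, so 100.0 / 0 raises ZeroDivisionError.
    # acc = correct.float().sum() * (100.0 / target[target != ignore_index].numel())

    # Patched formula: eps in numerator and denominator keeps it finite.
    eps = torch.finfo(torch.float32).eps
    correct_k = correct.float().sum(0, keepdim=True) + eps
    total_num = target[target != ignore_index].numel() + eps
    acc = correct_k * (100.0 / total_num)
    print(acc)  # tensor([100.]) up to floating-point error
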
diff --git a/tests/test_models/test_losses/test_utils.py b/tests/test_models/test_losses/test_utils.py
index ac5c6664f72..ab9927fe1cf 100644
--- a/tests/test_models/test_losses/test_utils.py
+++ b/tests/test_models/test_losses/test_utils.py
@@ -56,50 +56,56 @@ def test_accuracy():
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=None)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for ignore_index with a wrong prediction of that index
     true_label = torch.Tensor([2, 3, 1, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for ignore_index 1 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 75
+    assert torch.allclose(acc, torch.tensor(75.0))
 
     # test for ignore_index 4 with a wrong prediction of other index
     true_label = torch.Tensor([2, 0, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, ignore_index=4)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 80
+    assert torch.allclose(acc, torch.tensor(80.0))
+
+    # test for ignoring all the pixels
+    true_label = torch.Tensor([2, 2, 2, 2, 2]).long()
+    accuracy = Accuracy(topk=1, ignore_index=2)
+    acc = accuracy(pred, true_label)
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for top1
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for top1 with score thresh=0.8
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     accuracy = Accuracy(topk=1, thresh=0.8)
     acc = accuracy(pred, true_label)
-    assert acc.item() == 40
+    assert torch.allclose(acc, torch.tensor(40.0))
 
     # test for top2
     accuracy = Accuracy(topk=2)
     label = torch.Tensor([3, 2, 0, 0, 2]).long()
     acc = accuracy(pred, label)
-    assert acc.item() == 100
+    assert torch.allclose(acc, torch.tensor(100.0))
 
     # test for both top1 and top2
     accuracy = Accuracy(topk=(1, 2))
     true_label = torch.Tensor([2, 3, 0, 1, 2]).long()
     acc = accuracy(pred, true_label)
     for a in acc:
-        assert a.item() == 100
+        assert torch.allclose(a, torch.tensor(100.0))
 
     # topk is larger than pred class number
     with pytest.raises(AssertionError):
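One consequence of the eps terms, reflected in the test changes above, is
that the computed accuracies are no longer exact integers, so the asserts
move from acc.item() == N to torch.allclose. A small sketch of why (the
values here are chosen for illustration, not taken from the tests):

    import torch

    eps = torch.finfo(torch.float32).eps
    correct_k = torch.tensor([3.0]) + eps  # e.g. 3 of 4 kept pixels correct
    total_num = 4 + eps

    acc = correct_k * (100.0 / total_num)
    print(acc.item() == 75)                         # may be False: result is ~75.000001
    print(torch.allclose(acc, torch.tensor(75.0)))  # True within default tolerances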