diff --git a/mmseg/models/decode_heads/decode_head.py b/mmseg/models/decode_heads/decode_head.py index 1443a81da2..d08b1d0b60 100644 --- a/mmseg/models/decode_heads/decode_head.py +++ b/mmseg/models/decode_heads/decode_head.py @@ -261,5 +261,6 @@ def losses(self, seg_logit, seg_label): weight=seg_weight, ignore_index=self.ignore_index) - loss['acc_seg'] = accuracy(seg_logit, seg_label) + loss['acc_seg'] = accuracy( + seg_logit, seg_label, ignore_index=self.ignore_index) return loss diff --git a/mmseg/models/decode_heads/point_head.py b/mmseg/models/decode_heads/point_head.py index eb54bbcfe6..5e605271c7 100644 --- a/mmseg/models/decode_heads/point_head.py +++ b/mmseg/models/decode_heads/point_head.py @@ -264,7 +264,8 @@ def losses(self, point_logits, point_label): loss['point' + loss_module.loss_name] = loss_module( point_logits, point_label, ignore_index=self.ignore_index) - loss['acc_point'] = accuracy(point_logits, point_label) + loss['acc_point'] = accuracy( + point_logits, point_label, ignore_index=self.ignore_index) return loss def get_points_train(self, seg_logits, uncertainty_func, cfg): diff --git a/mmseg/models/losses/accuracy.py b/mmseg/models/losses/accuracy.py index f2cd16b7f9..7cd15e222f 100644 --- a/mmseg/models/losses/accuracy.py +++ b/mmseg/models/losses/accuracy.py @@ -2,12 +2,13 @@ import torch.nn as nn -def accuracy(pred, target, topk=1, thresh=None): +def accuracy(pred, target, topk=1, thresh=None, ignore_index=None): """Calculate accuracy according to the prediction and target. Args: pred (torch.Tensor): The model prediction, shape (N, num_class, ...) target (torch.Tensor): The target of each prediction, shape (N, , ...) + ignore_index (int | None): The label index to be ignored. Default: None topk (int | tuple[int], optional): If the predictions in ``topk`` matches the target, the predictions will be regarded as correct ones. Defaults to 1. @@ -43,17 +44,19 @@ def accuracy(pred, target, topk=1, thresh=None): if thresh is not None: # Only prediction values larger than thresh are counted as correct correct = correct & (pred_value > thresh).t() + correct = correct[:, target != ignore_index] res = [] for k in topk: correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / target.numel())) + res.append( + correct_k.mul_(100.0 / target[target != ignore_index].numel())) return res[0] if return_single else res class Accuracy(nn.Module): """Accuracy calculation module.""" - def __init__(self, topk=(1, ), thresh=None): + def __init__(self, topk=(1, ), thresh=None, ignore_index=None): """Module to calculate the accuracy. Args: @@ -65,6 +68,7 @@ def __init__(self, topk=(1, ), thresh=None): super().__init__() self.topk = topk self.thresh = thresh + self.ignore_index = ignore_index def forward(self, pred, target): """Forward function to calculate accuracy. @@ -76,4 +80,5 @@ def forward(self, pred, target): Returns: tuple[float]: The accuracies under different topk criterions. """ - return accuracy(pred, target, self.topk, self.thresh) + return accuracy(pred, target, self.topk, self.thresh, + self.ignore_index) diff --git a/tests/test_models/test_losses/test_utils.py b/tests/test_models/test_losses/test_utils.py index 1d94387ed7..ac5c6664f7 100644 --- a/tests/test_models/test_losses/test_utils.py +++ b/tests/test_models/test_losses/test_utils.py @@ -52,6 +52,30 @@ def test_accuracy(): pred = torch.Tensor([[0.2, 0.3, 0.6, 0.5], [0.1, 0.1, 0.2, 0.6], [0.9, 0.0, 0.0, 0.1], [0.4, 0.7, 0.1, 0.1], [0.0, 0.0, 0.99, 0]]) + # test for ignore_index + true_label = torch.Tensor([2, 3, 0, 1, 2]).long() + accuracy = Accuracy(topk=1, ignore_index=None) + acc = accuracy(pred, true_label) + assert acc.item() == 100 + + # test for ignore_index with a wrong prediction of that index + true_label = torch.Tensor([2, 3, 1, 1, 2]).long() + accuracy = Accuracy(topk=1, ignore_index=1) + acc = accuracy(pred, true_label) + assert acc.item() == 100 + + # test for ignore_index 1 with a wrong prediction of other index + true_label = torch.Tensor([2, 0, 0, 1, 2]).long() + accuracy = Accuracy(topk=1, ignore_index=1) + acc = accuracy(pred, true_label) + assert acc.item() == 75 + + # test for ignore_index 4 with a wrong prediction of other index + true_label = torch.Tensor([2, 0, 0, 1, 2]).long() + accuracy = Accuracy(topk=1, ignore_index=4) + acc = accuracy(pred, true_label) + assert acc.item() == 80 + # test for top1 true_label = torch.Tensor([2, 3, 0, 1, 2]).long() accuracy = Accuracy(topk=1)