Feature/sg 858 ignore multiple labels segmentation metrics support #1177

Merged
156 changes: 152 additions & 4 deletions src/super_gradients/training/metrics/segmentation_metrics.py
@@ -1,8 +1,10 @@
import typing

import numpy as np
import torch
import torchmetrics
from torchmetrics import Metric
from typing import Optional, Tuple, List, Union
from torchmetrics.utilities.distributed import reduce
from abc import ABC, abstractmethod

@@ -122,6 +124,35 @@ def intersection_and_union(im_pred, im_lab, num_class):
return area_inter, area_union


def _map_ignored_inds(target: torch.Tensor, ignore_index_list: List[int], unfiltered_num_classes: int) -> torch.Tensor:
"""
Creates a copy of target, mapping indices in range(unfiltered_num_classes) to
range(unfiltered_num_classes - len(ignore_index_list) + 1). Indices in ignore_index_list are mapped to 0,
which can later be used as the "ignore_index".

Example:
>>> _map_ignored_inds(torch.tensor([0, 1, 2, 3, 4, 5, 6]), ignore_index_list=[3, 5, 1], unfiltered_num_classes=7)
tensor([1, 0, 2, 0, 3, 0, 4])

:param target: torch.Tensor, tensor to perform the mapping on.
:param ignore_index_list: List[int], list of indices to map to 0 in the output tensor.
:param unfiltered_num_classes: int, Total number of possible class indices in target.

:return: mapped tensor as described above.
"""
target_copy = torch.zeros_like(target)
all_unfiltered_classes = list(range(unfiltered_num_classes))
filtered_classes = [i for i in all_unfiltered_classes if i not in ignore_index_list]
for map_val, cls_to_map in enumerate(filtered_classes, start=1):
target_copy[target == cls_to_map] = map_val

return target_copy


class AbstractMetricsArgsPrepFn(ABC):
"""
Abstract preprocess metrics arguments class.
@@ -164,7 +195,26 @@ def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

@register_metric(Metrics.PIXEL_ACCURACY)
class PixelAccuracy(Metric):
"""
Pixel Accuracy

Args:
ignore_label: Union[int, List[int]], specifying a target class(es) to ignore.
If given, pixels with this label (or labels) do not contribute to the returned score.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(self, ignore_label: Union[int, List[int]] = -100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.ignore_label = ignore_label
self.greater_is_better = True
@@ -174,46 +224,136 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep

def update(self, preds: torch.Tensor, target: torch.Tensor):
predict, target = self.metrics_args_prep_fn(preds, target)
labeled_mask = self._handle_multiple_ignored_inds(target)
pixel_labeled = torch.sum(labeled_mask)
pixel_correct = torch.sum((predict == target) * labeled_mask)
self.total_correct += pixel_correct
self.total_label += pixel_labeled

def _handle_multiple_ignored_inds(self, target):
if isinstance(self.ignore_label, typing.Iterable):
evaluated_classes_mask = torch.ones_like(target)
for ignored_label in self.ignore_label:
evaluated_classes_mask = evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0)
else:
evaluated_classes_mask = target.ne(self.ignore_label)

return evaluated_classes_mask

def compute(self):
_total_correct = self.total_correct.cpu().detach().numpy().astype("int64")
_total_label = self.total_label.cpu().detach().numpy().astype("int64")
pix_acc = np.float64(1.0) * _total_correct / (np.spacing(1, dtype=np.float64) + _total_label)
return pix_acc
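
For reference, a minimal usage sketch of PixelAccuracy with a list of ignored labels (values mirror the new unit test at the bottom of this diff; shapes assume preds as (B, C, N) scores and target as (B, N) labels):

import torch
from super_gradients.training.metrics import PixelAccuracy

# 5 classes; labels 1, 2 and 3 are ignored.
metric = PixelAccuracy(ignore_label=[3, 1, 2])
target = torch.tensor([[3, 1, 2, 4, 4, 4]])  # (B, N) label map
preds = torch.zeros((1, 5, 6))               # (B, C, N) scores; argmax -> class 4 everywhere
preds[:, 4] = 1
print(metric(preds, target))                 # 1.0 -- only the three class-4 pixels are counted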


def _handle_multiple_ignored_inds(ignore_index: Union[int, List[int]], num_classes: int):
"""
Helper method for variable assignment, prior to the

super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold)

call in segmentation metrics inheriting from torchmetrics.JaccardIndex.
When ignore_index is list, the num_classes being passed to the torchmetrics.JaccardIndex c'tor is set to be the one after
mapping of the ignored indices in ignore_index_list to 0. Hence, we set:
ignore_index=0,
And since we map all of the ignored indices to 0, it is if we removed them and introduces a new index:
num_classes = num_classes - len(ignore_index_list) +1
Unfiltered num_classes is used in .update() for mapping of the original indice values.
Sets ignore_index to 0
:param ignore_index: list or single int representing the class ind(ices) to ignore.
:param num_classes: int, num_classes (original, before mapping) being passed to segmentation metric classesץ
:return:ignore_index, ignore_index_list, num_classes, unfiltered_num_classesignore_index, ignore_index_list, num_classes, unfiltered_num_classes
"""
if isinstance(ignore_index, typing.Iterable):
ignore_index_list = ignore_index
unfiltered_num_classes = num_classes
num_classes = num_classes - len(ignore_index_list) + 1
ignore_index = 0
else:
unfiltered_num_classes = num_classes
ignore_index_list = None
return ignore_index, ignore_index_list, num_classes, unfiltered_num_classes
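
A worked example of the assignment above (values assumed for illustration: 7 classes, ignoring indices 3, 5 and 1):

ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds([3, 5, 1], 7)
# ignore_index           -> 0          (all ignored indices are folded into index 0)
# ignore_index_list      -> [3, 5, 1]
# num_classes            -> 5          (7 - 3 + 1: four kept classes plus the shared ignore slot)
# unfiltered_num_classes -> 7          (kept for remapping the raw labels in .update())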


@register_metric(Metrics.IOU)
class IoU(torchmetrics.JaccardIndex):
"""
IoU Metric

Args:
num_classes: Number of classes in the dataset.
ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore.
If given, this class index does not contribute to the returned score, regardless of reduction method.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
IMPORTANT: reduction="none" alongside a list of ignored indices is not supported and will raise an error.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(
self,
num_classes: int,
dist_sync_on_step: bool = False,
ignore_index: Optional[Union[int, List[int]]] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
):

if num_classes <= 1:
raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}")
if isinstance(ignore_index, typing.Iterable) and reduction == "none":
raise ValueError("passing multiple ignore indices ")
ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)

super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold)

self.unfiltered_num_classes = unfiltered_num_classes
self.ignore_index_list = ignore_index_list
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
self.greater_is_better = True

def update(self, preds, target: torch.Tensor):
preds, target = self.metrics_args_prep_fn(preds, target)
if self.ignore_index_list is not None:
target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes)
preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes)
super().update(preds=preds, target=target)
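
A minimal usage sketch of IoU with multiple ignored indices (numbers taken from the unit test at the bottom of this diff):

import torch
from super_gradients.training.metrics import IoU

metric = IoU(num_classes=5, ignore_index=[3, 1, 2])
target = torch.tensor([[3, 1, 2, 4, 4, 4]])  # first three pixels carry ignored labels
preds = torch.zeros((1, 5, 6))
preds[:, 4] = 1                              # argmax -> class 4 everywhere
print(metric(preds, target))                 # 0.5: (1.0 for class 4 + 0.0 absent score for class 0) / 2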


@register_metric(Metrics.DICE)
class Dice(torchmetrics.JaccardIndex):
"""
Dice Coefficient Metric

Args:
num_classes: Number of classes in the dataset.
ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore.
If given, this class index does not contribute to the returned score, regardless of reduction method.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
IMPORTANT: reduction="none" alongside a list of ignored indices is not supported and will raise an error.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(
self,
num_classes: int,
@@ -227,12 +367,20 @@ def __init__(
if num_classes <= 1:
raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}")

ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)

super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold)

self.ignore_index_list = ignore_index_list
self.unfiltered_num_classes = unfiltered_num_classes
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
self.greater_is_better = True

def update(self, preds, target: torch.Tensor):
preds, target = self.metrics_args_prep_fn(preds, target)
if self.ignore_index_list is not None:
target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes)
preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes)
super().update(preds=preds, target=target)

def compute(self) -> torch.Tensor:
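
And a matching sketch for Dice, this time with class 0 present in both preds and target so no absent score is involved (numbers mirror the unit test below):

import torch
from super_gradients.training.metrics import Dice

metric = Dice(num_classes=5, ignore_index=[3, 1, 2])
target = torch.tensor([[3, 1, 2, 0, 4, 4]])  # ignored labels 3, 1, 2 plus real classes 0 and 4
preds = torch.zeros((1, 5, 6))
preds[:, 4] = 1                              # argmax -> class 4 everywhere...
preds[0, 0, 3] = 2                           # ...except position 3, which predicts class 0
print(metric(preds, target))                 # 1.0: both evaluated classes are predicted perfectly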
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
@@ -29,6 +29,7 @@
from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest
from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
from tests.unit_tests.multiple_ignore_indices_segmentation_metrics_test import TestSegmentationMetricsMultipleIgnored
from tests.unit_tests.phase_delegates_test import ContextMethodsTest
from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationDataset
from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
@@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored))

def _add_modules_to_end_to_end_tests_suite(self):
"""
57 changes: 57 additions & 0 deletions tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py
@@ -0,0 +1,57 @@
import unittest

import torch

from super_gradients.training.metrics import IoU, PixelAccuracy, Dice


class TestSegmentationMetricsMultipleIgnored(unittest.TestCase):
def test_iou_with_multiple_ignored_classes_and_absent_score(self):
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

# preds after argmax -> [4, 4, 4, 4, 4, 4]
# (1.0 + 0.0) / 2: IoU is 1.0 for class 4 and 0.0 (absent score) for class 0
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5)

def test_iou_with_multiple_ignored_classes_no_absent_score(self):
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1
pred[0, 0, 3] = 2

# preds after argmax -> [4, 4, 4, 0, 4, 4]
# (1.0 + 1.0) / 2: IoU is 1.0 for class 4 and 1.0 for class 0
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1)

def test_dice_with_multiple_ignored_classes_and_absent_score(self):
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5)

def test_dice_with_multiple_ignored_classes_no_absent_score(self):
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1
pred[0, 0, 3] = 2

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0)

def test_pixelaccuracy_with_multiple_ignored_classes(self):
metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0)


if __name__ == "__main__":
unittest.main()