diff --git a/src/super_gradients/training/metrics/segmentation_metrics.py b/src/super_gradients/training/metrics/segmentation_metrics.py index a4a847dbe2..d934bc1d27 100755 --- a/src/super_gradients/training/metrics/segmentation_metrics.py +++ b/src/super_gradients/training/metrics/segmentation_metrics.py @@ -1,8 +1,10 @@ +import typing + import numpy as np import torch import torchmetrics from torchmetrics import Metric -from typing import Optional, Tuple +from typing import Optional, Tuple, List, Union from torchmetrics.utilities.distributed import reduce from abc import ABC, abstractmethod @@ -122,6 +124,35 @@ def intersection_and_union(im_pred, im_lab, num_class): return area_inter, area_union +def _map_ignored_inds(target: torch.Tensor, ignore_index_list: List[int], unfiltered_num_classes: int) -> torch.Tensor: + """ + Creaetes a copy of target, mapping indices in range(unfiltered_num_classes) to range(unfiltered_num_classes-len( + ignore_index_list)+1). Indices in ignore_index_list are being mapped to 0, which can later on be used as + "ignore_index". + + Example: + >>>_map_ignored_inds(torch.tensor([0,1,2,3,4,5,6]), ignore_index_list=[3,5,1], unfiltered_num_classes=7) + >>> tensor([1, 0, 2, 0, 3, 0, 4]) + + + + :param target: torch.Tensor, tensor to perform the mapping on. + :param ignore_index_list: List[int], list of indices to map to 0 in the output tensor. + :param unfiltered_num_classes: int, Total number of possible class indices in target. + + :return: mapped tensor as described above. + """ + target_copy = torch.zeros_like(target) + all_unfiltered_classes = list(range(unfiltered_num_classes)) + filtered_classes = [i for i in all_unfiltered_classes if i not in ignore_index_list] + for mapped_idx in range(len(filtered_classes)): + cls_to_map = filtered_classes[mapped_idx] + map_val = mapped_idx + 1 + target_copy[target == cls_to_map] = map_val + + return target_copy + + class AbstractMetricsArgsPrepFn(ABC): """ Abstract preprocess metrics arguments class. @@ -164,7 +195,26 @@ def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten @register_metric(Metrics.PIXEL_ACCURACY) class PixelAccuracy(Metric): - def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): + """ + Pixel Accuracy + + Args: + ignore_label: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + + def __init__(self, ignore_label: Union[int, List[int]] = -100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None): super().__init__(dist_sync_on_step=dist_sync_on_step) self.ignore_label = ignore_label self.greater_is_better = True @@ -174,13 +224,23 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep def update(self, preds: torch.Tensor, target: torch.Tensor): predict, target = self.metrics_args_prep_fn(preds, target) + labeled_mask = self._handle_multiple_ignored_inds(target) - labeled_mask = target.ne(self.ignore_label) pixel_labeled = torch.sum(labeled_mask) pixel_correct = torch.sum((predict == target) * labeled_mask) self.total_correct += pixel_correct self.total_label += pixel_labeled + def _handle_multiple_ignored_inds(self, target): + if isinstance(self.ignore_label, typing.Iterable): + evaluated_classes_mask = torch.ones_like(target) + for ignored_label in self.ignore_label: + evaluated_classes_mask = evaluated_classes_mask.masked_fill(target.eq(ignored_label), 0) + else: + evaluated_classes_mask = target.ne(self.ignore_label) + + return evaluated_classes_mask + def compute(self): _total_correct = self.total_correct.cpu().detach().numpy().astype("int64") _total_label = self.total_label.cpu().detach().numpy().astype("int64") @@ -188,13 +248,63 @@ def compute(self): return pix_acc +def _handle_multiple_ignored_inds(ignore_index: Union[int, List[int]], num_classes: int): + """ + Helper method for variable assignment, prior to the + + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + call in segmentation metrics inheriting from torchmetrics.JaccardIndex. + When ignore_index is list, the num_classes being passed to the torchmetrics.JaccardIndex c'tor is set to be the one after + mapping of the ignored indices in ignore_index_list to 0. Hence, we set: + ignore_index=0, + And since we map all of the ignored indices to 0, it is if we removed them and introduces a new index: + num_classes = num_classes - len(ignore_index_list) +1 + Unfiltered num_classes is used in .update() for mapping of the original indice values. + Sets ignore_index to 0 + :param ignore_index: list or single int representing the class ind(ices) to ignore. + :param num_classes: int, num_classes (original, before mapping) being passed to segmentation metric classesׄ + :return:ignore_index, ignore_index_list, num_classes, unfiltered_num_classesignore_index, ignore_index_list, num_classes, unfiltered_num_classes + """ + if isinstance(ignore_index, typing.Iterable): + ignore_index_list = ignore_index + unfiltered_num_classes = num_classes + num_classes = num_classes - len(ignore_index_list) + 1 + ignore_index = 0 + else: + unfiltered_num_classes = num_classes + ignore_index_list = None + return ignore_index, ignore_index_list, num_classes, unfiltered_num_classes + + @register_metric(Metrics.IOU) class IoU(torchmetrics.JaccardIndex): + """ + IoU Metric + + Args: + num_classes: Number of classes in the dataset. + ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + threshold: Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + def __init__( self, num_classes: int, dist_sync_on_step: bool = False, - ignore_index: Optional[int] = None, + ignore_index: Optional[Union[int, List[int]]] = None, reduction: str = "elementwise_mean", threshold: float = 0.5, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None, @@ -202,18 +312,48 @@ def __init__( if num_classes <= 1: raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}") + if isinstance(ignore_index, typing.Iterable) and reduction == "none": + raise ValueError("passing multiple ignore indices ") + ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes) super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + self.unfiltered_num_classes = unfiltered_num_classes + self.ignore_index_list = ignore_index_list self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) self.greater_is_better = True def update(self, preds, target: torch.Tensor): preds, target = self.metrics_args_prep_fn(preds, target) + if self.ignore_index_list is not None: + target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes) + preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes) super().update(preds=preds, target=target) @register_metric(Metrics.DICE) class Dice(torchmetrics.JaccardIndex): + """ + Dice Coefficient Metric + + Args: + num_classes: Number of classes in the dataset. + ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore. + If given, this class index does not contribute to the returned score, regardless of reduction method. + Has no effect if given an int that is not in the range [0, num_classes-1]. + By default, no index is ignored, and all classes are used. + IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error. + threshold: Threshold value for binary or multi-label probabilities. + reduction: a method to reduce metric score over labels: + + - ``'elementwise_mean'``: takes the mean (default) + - ``'sum'``: takes the sum + - ``'none'``: no reduction will be applied + + metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics. + By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True) + """ + def __init__( self, num_classes: int, @@ -227,12 +367,20 @@ def __init__( if num_classes <= 1: raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}") + ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes) + super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold) + + self.ignore_index_list = ignore_index_list + self.unfiltered_num_classes = unfiltered_num_classes self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True) self.greater_is_better = True def update(self, preds, target: torch.Tensor): preds, target = self.metrics_args_prep_fn(preds, target) + if self.ignore_index_list is not None: + target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes) + preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes) super().update(preds=preds, target=target) def compute(self) -> torch.Tensor: diff --git a/tests/deci_core_unit_test_suite_runner.py b/tests/deci_core_unit_test_suite_runner.py index d7beaa7658..7d7289ac92 100644 --- a/tests/deci_core_unit_test_suite_runner.py +++ b/tests/deci_core_unit_test_suite_runner.py @@ -29,6 +29,7 @@ from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest +from tests.unit_tests.multiple_ignore_indices_segmentation_metrics_test import TestSegmentationMetricsMultipleIgnored from tests.unit_tests.phase_delegates_test import ContextMethodsTest from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationDataset from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest @@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self): self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest)) self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode)) + self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored)) def _add_modules_to_end_to_end_tests_suite(self): """ diff --git a/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py new file mode 100644 index 0000000000..34f8619f77 --- /dev/null +++ b/tests/unit_tests/multiple_ignore_indices_segmentation_metrics_test.py @@ -0,0 +1,57 @@ +import unittest + +import torch + +from super_gradients.training.metrics import IoU, PixelAccuracy, Dice + + +class TestSegmentationMetricsMultipleIgnored(unittest.TestCase): + def test_iou_with_multiple_ignored_classes_and_absent_score(self): + metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + # preds after onehot -> [4,4,4,4,4,4] + # (1 + 0)/2 : 1.0 for class 4 score and 0 for absent score for class 0 + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + + def test_iou_with_multiple_ignored_classes_no_absent_score(self): + metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + pred[0, 0, 3] = 2 + + # preds after onehot -> [4,4,4,0,4,4] + # (1 + 1)/2 : 1.0 for class 4 score and 1 for class 0 + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1) + + def test_dice_with_multiple_ignored_classes_and_absent_score(self): + metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5) + + def test_dice_with_multiple_ignored_classes_no_absent_score(self): + metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + pred[0, 0, 3] = 2 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) + + def test_pixelaccuracy_with_multiple_ignored_classes(self): + metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2]) + target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]]) + pred = torch.zeros((1, 5, 6)) + pred[:, 4] = 1 + + self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0) + + +if __name__ == "__main__": + unittest.main()