Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/sg 858 ignore multiple labels segmentation metrics support #1177

Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 117 additions & 4 deletions src/super_gradients/training/metrics/segmentation_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
import torchmetrics
from torchmetrics import Metric
from typing import Optional, Tuple
from typing import Optional, Tuple, List, Union
from torchmetrics.utilities.distributed import reduce
from abc import ABC, abstractmethod

Expand Down Expand Up @@ -122,6 +122,18 @@ def intersection_and_union(im_pred, im_lab, num_class):
return area_inter, area_union


def _map_ignored_inds(target: torch.Tensor, ignore_index_list, unfiltered_num_classes) -> torch.Tensor:
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
target_copy = torch.zeros_like(target)
all_unfiltered_classes = list(range(unfiltered_num_classes))
filtered_classes = [i for i in all_unfiltered_classes if i not in ignore_index_list]
for mapped_idx in range(len(filtered_classes)):
cls_to_map = filtered_classes[mapped_idx]
map_val = mapped_idx + 1
target_copy[target == cls_to_map] = map_val

return target_copy


class AbstractMetricsArgsPrepFn(ABC):
"""
Abstract preprocess metrics arguments class.
Expand Down Expand Up @@ -164,7 +176,26 @@ def __call__(self, preds, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Ten

@register_metric(Metrics.PIXEL_ACCURACY)
class PixelAccuracy(Metric):
def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
"""
Pixel Accuracy

Args:
ignore_label: Optional[Union[int, List[int]]], specifying a target class(es) to ignore.
If given, this class index does not contribute to the returned score, regardless of reduction method.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(self, ignore_label: Union[int, List[int]] = -100, dist_sync_on_step=False, metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None):
super().__init__(dist_sync_on_step=dist_sync_on_step)
self.ignore_label = ignore_label
self.greater_is_better = True
Expand All @@ -174,27 +205,73 @@ def __init__(self, ignore_label=-100, dist_sync_on_step=False, metrics_args_prep

def update(self, preds: torch.Tensor, target: torch.Tensor):
predict, target = self.metrics_args_prep_fn(preds, target)
labeled_mask = self._handle_multiple_ignored_inds(target)

labeled_mask = target.ne(self.ignore_label)
pixel_labeled = torch.sum(labeled_mask)
pixel_correct = torch.sum((predict == target) * labeled_mask)
self.total_correct += pixel_correct
self.total_label += pixel_labeled

def _handle_multiple_ignored_inds(self, target):
if isinstance(self.ignore_label, list):
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
labeled_mask = None
for ignored_label in self.ignore_label:
if labeled_mask is None:
labeled_mask = target.ne(ignored_label)
else:
labeled_mask = torch.logical_and(labeled_mask, target.ne(ignored_label))
else:
labeled_mask = target.ne(self.ignore_label)

return labeled_mask
shaydeci marked this conversation as resolved.
Show resolved Hide resolved

def compute(self):
_total_correct = self.total_correct.cpu().detach().numpy().astype("int64")
_total_label = self.total_label.cpu().detach().numpy().astype("int64")
pix_acc = np.float64(1.0) * _total_correct / (np.spacing(1, dtype=np.float64) + _total_label)
return pix_acc


def _handle_multiple_ignored_inds(ignore_index, num_classes):
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(ignore_index, list):
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
ignore_index_list = ignore_index
unfiltered_num_classes = num_classes
num_classes = num_classes - len(ignore_index_list) + 1
ignore_index = 0
else:
unfiltered_num_classes = num_classes
ignore_index_list = None
return ignore_index, ignore_index_list, num_classes, unfiltered_num_classes


@register_metric(Metrics.IOU)
class IoU(torchmetrics.JaccardIndex):
"""
IoU Metric

Args:
num_classes: Number of classes in the dataset.
ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore.
If given, this class index does not contribute to the returned score, regardless of reduction method.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(
self,
num_classes: int,
dist_sync_on_step: bool = False,
ignore_index: Optional[int] = None,
ignore_index: Optional[Union[int, List[int]]] = None,
reduction: str = "elementwise_mean",
threshold: float = 0.5,
metrics_args_prep_fn: Optional[AbstractMetricsArgsPrepFn] = None,
Expand All @@ -203,17 +280,46 @@ def __init__(
if num_classes <= 1:
raise ValueError(f"IoU class only for multi-class usage! For binary usage, please call {BinaryIOU.__name__}")

ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)

super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold)

self.unfiltered_num_classes = unfiltered_num_classes
self.ignore_index_list = ignore_index_list
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
self.greater_is_better = True

def update(self, preds, target: torch.Tensor):
preds, target = self.metrics_args_prep_fn(preds, target)
if self.ignore_index_list is not None:
target = _map_ignored_inds(target, self.ignore_index_list, self.unfiltered_num_classes)
preds = _map_ignored_inds(preds, self.ignore_index_list, self.unfiltered_num_classes)
super().update(preds=preds, target=target)


@register_metric(Metrics.DICE)
class Dice(torchmetrics.JaccardIndex):
"""
Dice Coefficient Metric

Args:
num_classes: Number of classes in the dataset.
ignore_index: Optional[Union[int, List[int]]], specifying a target class(es) to ignore.
If given, this class index does not contribute to the returned score, regardless of reduction method.
Has no effect if given an int that is not in the range [0, num_classes-1].
By default, no index is ignored, and all classes are used.
IMPORTANT: reduction="none" alongside with a list of ignored indices is not supported and will raise an error.
threshold: Threshold value for binary or multi-label probabilities.
reduction: a method to reduce metric score over labels:

- ``'elementwise_mean'``: takes the mean (default)
- ``'sum'``: takes the sum
- ``'none'``: no reduction will be applied

metrics_args_prep_fn: Callable, inputs preprocess function applied on preds, target before updating metrics.
By default set to PreprocessSegmentationMetricsArgs(apply_arg_max=True)
"""

def __init__(
self,
num_classes: int,
Expand All @@ -227,12 +333,19 @@ def __init__(
if num_classes <= 1:
raise ValueError(f"Dice class only for multi-class usage! For binary usage, please call {BinaryDice.__name__}")

ignore_index, ignore_index_list, num_classes, unfiltered_num_classes = _handle_multiple_ignored_inds(ignore_index, num_classes)

super().__init__(num_classes=num_classes, dist_sync_on_step=dist_sync_on_step, ignore_index=ignore_index, reduction=reduction, threshold=threshold)

self.ignore_index_list = ignore_index_list
self.unfiltered_num_classes = unfiltered_num_classes
self.metrics_args_prep_fn = metrics_args_prep_fn or PreprocessSegmentationMetricsArgs(apply_arg_max=True)
self.greater_is_better = True

def update(self, preds, target: torch.Tensor):
preds, target = self.metrics_args_prep_fn(preds, target)
if self.ignore_index_list is not None:
target = _map_ignored_inds(target, self.ignore_index_list, self.ignore_index)
shaydeci marked this conversation as resolved.
Show resolved Hide resolved
super().update(preds=preds, target=target)

def compute(self) -> torch.Tensor:
Expand Down
2 changes: 2 additions & 0 deletions tests/deci_core_unit_test_suite_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tests.unit_tests.load_checkpoint_test import LoadCheckpointTest
from tests.unit_tests.local_ckpt_head_replacement_test import LocalCkptHeadReplacementTest
from tests.unit_tests.max_batches_loop_break_test import MaxBatchesLoopBreakTest
from tests.unit_tests.multiple_ignore_indices_segmentation_metrics_test import TestSegmentationMetricsMultipleIgnored
from tests.unit_tests.phase_delegates_test import ContextMethodsTest
from tests.unit_tests.pose_estimation_dataset_test import TestPoseEstimationDataset
from tests.unit_tests.preprocessing_unit_test import PreprocessingUnitTest
Expand Down Expand Up @@ -141,6 +142,7 @@ def _add_modules_to_unit_tests_suite(self):
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestYOLONAS))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(DeprecationsUnitTest))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestMinSamplesSingleNode))
self.unit_tests_suite.addTest(self.test_loader.loadTestsFromModule(TestSegmentationMetricsMultipleIgnored))

def _add_modules_to_end_to_end_tests_suite(self):
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest

import torch

from super_gradients.training.metrics import IoU, PixelAccuracy, Dice


class TestSegmentationMetricsMultipleIgnored(unittest.TestCase):
def test_iou_with_multiple_ignored_classes_and_absent_score(self):
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

# preds after onehot -> [4,4,4,4,4,4]
# (1 + 0)/2 : 1.0 for class 4 score and 0 for absent score for class 0
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5)

def test_iou_with_multiple_ignored_classes_no_absent_score(self):
metric_multi_ignored = IoU(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1
pred[0, 0, 3] = 2

# preds after onehot -> [4,4,4,0,4,4]
# (1 + 1)/2 : 1.0 for class 4 score and 1 for class 0
self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1)

def test_dice_with_multiple_ignored_classes_and_absent_score(self):
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5)

def test_dice_with_multiple_ignored_classes_no_absent_score(self):
metric_multi_ignored = Dice(num_classes=5, ignore_index=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 0, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1
pred[0, 0, 3] = 2

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 0.5)

def test_pixelaccuracy_with_multiple_ignored_classes(self):
metric_multi_ignored = PixelAccuracy(ignore_label=[3, 1, 2])
target_multi_ignored = torch.tensor([[3, 1, 2, 4, 4, 4]])
pred = torch.zeros((1, 5, 6))
pred[:, 4] = 1

self.assertEqual(metric_multi_ignored(pred, target_multi_ignored), 1.0)


if __name__ == "__main__":
unittest.main()