From 123c778daafb79b4a93db6207a2d194db373d9ca Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Wed, 31 Jul 2024 01:15:16 -0400 Subject: [PATCH 01/11] Fix generalized dice computation Signed-off-by: Suraj Pai Similar functionality to torchmetrics Update Lint and update sum_over_labels Update docstring Update docstring --- monai/metrics/generalized_dice.py | 112 +++++++++++++------------ tests/test_compute_generalized_dice.py | 40 +++++---- 2 files changed, 84 insertions(+), 68 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index e56bd46592..047bfd0ab2 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -20,109 +20,108 @@ class GeneralizedDiceScore(CumulativeIterationMetric): - """Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in: + """ + Compute the Generalized Dice Score metric between tensors. + This metric is the complement of the Generalized Dice Loss defined in: Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning - loss function for highly unbalanced segmentations. DLMIA 2017. + loss function for highly unbalanced segmentations. DLMIA 2017. - The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first - or batch-first tensors, i.e., CHW[D] or BCHW[D]. + The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D]. Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`. Args: - include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the + include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. - reduction (str, optional): define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform + reduction: Define mode of reduction to the metrics. Available reduction modes: + {``"none"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean"`, ``"sum"`}. Defaults to ``"mean"``. + If "none", will not do reduction. + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. Raises: - ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}. + ValueError: When the `reduction` is not one of MetricReduction enum. """ def __init__( - self, - include_background: bool = True, - reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, - weight_type: Weight | str = Weight.SQUARE, + self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE ) -> None: super().__init__() self.include_background = include_background - reduction_options = [ - "none", - "mean_batch", - "sum_batch", - MetricReduction.NONE, - MetricReduction.MEAN_BATCH, - MetricReduction.SUM_BATCH, - ] - self.reduction = reduction - if self.reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") + self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) + self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] - """Computes the Generalized Dice Score and returns a tensor with its per image values. + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, + y_pred: Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + + Returns: + torch.Tensor: Per batch and per class Generalized Dice Score. Raises: - ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. + ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. """ return compute_generalized_dice( - y_pred=y_pred, y=y, include_background=self.include_background, weight_type=self.weight_type + y_pred=y_pred, + y=y, + include_background=self.include_background, + weight_type=self.weight_type, + sum_over_labels=self.sum_over_labels, ) - def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: + def aggregate(self) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. - Args: - reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics. - Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}. - Defaults to ``"mean"``. If "none", will not do reduction. + Returns: + torch.Tensor: Aggregated metric value. + + Raises: + ValueError: If the data to aggregate is not a PyTorch Tensor. """ data = self.get_buffer() if not isinstance(data, torch.Tensor): raise ValueError("The data to aggregate must be a PyTorch Tensor.") - # Validate reduction argument if specified - if reduction is not None: - reduction_options = ["none", "mean", "sum", "mean_batch", "sum_batch"] - if reduction not in reduction_options: - raise ValueError(f"reduction must be one of {reduction_options}") - # Do metric reduction and return - f, _ = do_metric_reduction(data, reduction or self.reduction) + f, _ = do_metric_reduction(data, self.reduction) return f def compute_generalized_dice( - y_pred: torch.Tensor, y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE + y_pred: torch.Tensor, + y: torch.Tensor, + include_background: bool = True, + weight_type: Weight | str = Weight.SQUARE, + sum_over_labels: bool = False, ) -> torch.Tensor: - """Computes the Generalized Dice Score and returns a tensor with its per image values. + """ + Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format + y_pred: Binarized segmentation model output. It should be binarized, in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. - include_background (bool, optional): whether to include score computation on the first channel of the + y: Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. + include_background: Whether to include score computation on the first channel of the predicted output. Defaults to True. - weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to + weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. + sum_over_labels: Whether to sum the numerator and denominator across all labels before the final computation. Returns: - torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. + torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. Raises: - ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, + ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions, or `y_pred` and `y` don't have the same shape. """ # Ensure tensors have at least 3 dimensions and have the same shape @@ -158,16 +157,21 @@ def compute_generalized_dice( b[infs] = 0 b[infs] = torch.max(b) - # Compute the weighted numerator and denominator, summing along the class axis - numer = 2.0 * (intersection * w).sum(dim=1) - denom = (denominator * w).sum(dim=1) + # Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True + if sum_over_labels: + numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) + denom = (denominator * w).sum(dim=1, keepdim=True) + y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) + else: + numer = 2.0 * (intersection * w) + denom = denominator * w + y_pred_o = y_pred_o # Compute the score generalized_dice_score = numer / denom # Handle zero division. Where denom == 0 and the prediction volume is 0, score is 1. # Where denom == 0 but the prediction volume is not 0, score is 0 - y_pred_o = y_pred_o.sum(dim=-1) denom_zeros = denom == 0 generalized_dice_score[denom_zeros] = torch.where( (y_pred_o == 0)[denom_zeros], diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index e04444e988..fd7745245c 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -22,7 +22,7 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background -TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1) +TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) { "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), @@ -32,7 +32,7 @@ ] # remove background -TEST_CASE_2 = [ # y (2, 1, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background) { "y_pred": torch.tensor( [ @@ -48,11 +48,11 @@ ), "include_background": False, }, - [0.1667, 0.6667], + [0.416667], ] # should return 0 for both cases -TEST_CASE_3 = [ +TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 3) { "y_pred": torch.tensor( [ @@ -68,7 +68,7 @@ ), "include_background": True, }, - [0.0, 0.0], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], ] TEST_CASE_4 = [ @@ -87,11 +87,11 @@ ] ), }, - [0.5455], + [0.678571, 0.2, 0.333333], ] TEST_CASE_5 = [ - {"include_background": True, "reduction": "sum_batch"}, + {"include_background": True, "reduction": "sum"}, { "y_pred": torch.tensor( [ @@ -106,16 +106,28 @@ ] ), }, - 1.0455, + [1.045455], ] -TEST_CASE_6 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_6 = [ + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] -TEST_CASE_7 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_7 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] -TEST_CASE_8 = [{"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [0.0000, 0.0000]] +TEST_CASE_8 = [ + {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[0.0000, 0.0000], [0.0000, 0.0000]], +] -TEST_CASE_9 = [{"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [1.0000, 1.0000]] +TEST_CASE_9 = [ + {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, + [[1.0000, 1.0000], [1.0000, 1.0000]], +] class TestComputeGeneralizedDiceScore(unittest.TestCase): @@ -126,7 +138,7 @@ def test_device(self, input_data, _expected_value): np.testing.assert_equal(result.device, input_data["y_pred"].device) # Functional part tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) def test_value(self, input_data, expected_value): result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) @@ -146,7 +158,7 @@ def test_value_class(self, input_data, expected_value): vals["y"] = input_data.pop("y") generalized_dice_score = GeneralizedDiceScore(**input_data) generalized_dice_score(**vals) - result = generalized_dice_score.aggregate(reduction="none") + result = generalized_dice_score.aggregate() np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) # Aggregation tests From 2081cd263176053a78bc987e301a0ceb6a417b93 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Tue, 27 Aug 2024 14:15:57 -0400 Subject: [PATCH 02/11] Add updates from comments --- monai/metrics/generalized_dice.py | 35 +++++++++++++++++++++++-------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 047bfd0ab2..c743379700 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -14,7 +14,7 @@ import torch from monai.metrics.utils import do_metric_reduction, ignore_background -from monai.utils import MetricReduction, Weight, look_up_option +from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option from .metric import CumulativeIterationMetric @@ -44,16 +44,27 @@ class GeneralizedDiceScore(CumulativeIterationMetric): ValueError: When the `reduction` is not one of MetricReduction enum. """ + @deprecated_arg_default( + "reduction", + MetricReduction.MEAN_BATCH, + MetricReduction.MEAN, + since="1.4.0", + replaced="1.4.0", + msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction", + ) def __init__( - self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE + self, + include_background: bool = True, + reduction: str = MetricReduction.MEAN, + weight_type: Weight | str = Weight.SQUARE, ) -> None: super().__init__() self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) - self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN + self.sum_over_classes = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN - def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ Computes the Generalized Dice Score and returns a tensor with its per image values. @@ -63,7 +74,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. Returns: - torch.Tensor: Per batch and per class Generalized Dice Score. + torch.Tensor: Generalized Dice Score averaged across batch and class Raises: ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape. @@ -73,10 +84,16 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor y=y, include_background=self.include_background, weight_type=self.weight_type, - sum_over_labels=self.sum_over_labels, + sum_over_classes=self.sum_over_classes, ) - def aggregate(self) -> torch.Tensor: + @deprecated_arg( + "reduction", + since="1.4.0", + removed="1.7.0", + msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute", + ) + def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor: """ Execute reduction logic for the output of `compute_generalized_dice`. @@ -101,7 +118,7 @@ def compute_generalized_dice( y: torch.Tensor, include_background: bool = True, weight_type: Weight | str = Weight.SQUARE, - sum_over_labels: bool = False, + sum_over_classes: bool = False, ) -> torch.Tensor: """ Computes the Generalized Dice Score and returns a tensor with its per image values. @@ -158,7 +175,7 @@ def compute_generalized_dice( b[infs] = torch.max(b) # Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True - if sum_over_labels: + if sum_over_classes: numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) denom = (denominator * w).sum(dim=1, keepdim=True) y_pred_o = y_pred_o.sum(dim=-1, keepdim=True) From 3805fea8a2a38871a1b042082036ecf12499b633 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Tue, 27 Aug 2024 14:21:33 -0400 Subject: [PATCH 03/11] Minor doc fixes --- monai/metrics/generalized_dice.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index c743379700..feda453935 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -35,8 +35,8 @@ class GeneralizedDiceScore(CumulativeIterationMetric): include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. reduction: Define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean"`, ``"sum"`}. Defaults to ``"mean"``. - If "none", will not do reduction. + {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. @@ -46,8 +46,8 @@ class GeneralizedDiceScore(CumulativeIterationMetric): @deprecated_arg_default( "reduction", - MetricReduction.MEAN_BATCH, - MetricReduction.MEAN, + old_default=MetricReduction.MEAN_BATCH, + new_default=MetricReduction.MEAN, since="1.4.0", replaced="1.4.0", msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction", @@ -55,7 +55,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - reduction: str = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN, weight_type: Weight | str = Weight.SQUARE, ) -> None: super().__init__() @@ -69,9 +69,9 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred: Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, + y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. + y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`. Returns: torch.Tensor: Generalized Dice Score averaged across batch and class @@ -124,15 +124,15 @@ def compute_generalized_dice( Computes the Generalized Dice Score and returns a tensor with its per image values. Args: - y_pred: Binarized segmentation model output. It should be binarized, in one-hot format + y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions. - y: Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. + y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`. include_background: Whether to include score computation on the first channel of the predicted output. Defaults to True. - weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to + weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. - sum_over_labels: Whether to sum the numerator and denominator across all labels before the final computation. + sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation. Returns: torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes]. From fb6e66067dcd10e5b22aa95847b9db80ba402662 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Tue, 27 Aug 2024 19:58:44 -0400 Subject: [PATCH 04/11] Update version string --- monai/metrics/generalized_dice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index feda453935..ffcc1e0796 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -48,8 +48,8 @@ class GeneralizedDiceScore(CumulativeIterationMetric): "reduction", old_default=MetricReduction.MEAN_BATCH, new_default=MetricReduction.MEAN, - since="1.4.0", - replaced="1.4.0", + since="1.3.3", + replaced="1.3.3", msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction", ) def __init__( @@ -89,7 +89,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor @deprecated_arg( "reduction", - since="1.4.0", + since="1.3.3", removed="1.7.0", msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute", ) From 4ea6a2c9cfed83d7eac9914702561877264e2256 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 5 Sep 2024 11:26:11 -0400 Subject: [PATCH 05/11] Update monai/metrics/generalized_dice.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Suraj Pai --- monai/metrics/generalized_dice.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index ffcc1e0796..c2eb847737 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -48,9 +48,9 @@ class GeneralizedDiceScore(CumulativeIterationMetric): "reduction", old_default=MetricReduction.MEAN_BATCH, new_default=MetricReduction.MEAN, - since="1.3.3", - replaced="1.3.3", - msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction", + since="1.4.0", + replaced="1.5.0", + msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'.", ) def __init__( self, From 2bc6dd82b80905e1a4206476cd84a16ce7ecc94d Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 5 Sep 2024 12:29:45 -0400 Subject: [PATCH 06/11] Update monai/metrics/generalized_dice.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Suraj Pai --- monai/metrics/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index c2eb847737..92e88098ac 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -55,7 +55,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): def __init__( self, include_background: bool = True, - reduction: MetricReduction | str = MetricReduction.MEAN, + reduction: MetricReduction | str = MetricReduction.MEAN_BATCH, weight_type: Weight | str = Weight.SQUARE, ) -> None: super().__init__() From 827406095c24d3f50bc2d581fa65eea001e5321c Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 5 Sep 2024 13:10:13 -0400 Subject: [PATCH 07/11] Update reduction for channel cases --- monai/metrics/generalized_dice.py | 7 ++++++- tests/test_compute_generalized_dice.py | 26 +++++++++++++++++++++++--- 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 92e88098ac..70041ad1c9 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -62,7 +62,12 @@ def __init__( self.include_background = include_background self.reduction = look_up_option(reduction, MetricReduction) self.weight_type = look_up_option(weight_type, Weight) - self.sum_over_classes = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN + self.sum_over_classes = self.reduction in { + MetricReduction.SUM, + MetricReduction.MEAN, + MetricReduction.MEAN_CHANNEL, + MetricReduction.SUM_CHANNEL, + } def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override] """ diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index fd7745245c..448a17eaa5 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -32,7 +32,7 @@ ] # remove background -TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 2) (no background) +TEST_CASE_2 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2) (no background) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -47,8 +47,9 @@ ] ), "include_background": False, + "reduction": "mean_batch", }, - [0.416667], + [0.583333, 0.333333], ] # should return 0 for both cases @@ -129,6 +130,25 @@ [[1.0000, 1.0000], [1.0000, 1.0000]], ] +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) + {"include_background": True, "reduction": "mean_channel"}, + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + }, + [0.545455, 0.545455], +] + class TestComputeGeneralizedDiceScore(unittest.TestCase): @@ -162,7 +182,7 @@ def test_value_class(self, input_data, expected_value): np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) # Aggregation tests - @parameterized.expand([TEST_CASE_4, TEST_CASE_5]) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_10]) def test_nans_class(self, params, input_data, expected_value): generalized_dice_score = GeneralizedDiceScore(**params) generalized_dice_score(**input_data) From 2e7e04cef7120b0826a0dbed3baf8b595dbbe2d5 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Thu, 5 Sep 2024 13:28:42 -0400 Subject: [PATCH 08/11] Fix line length Signed-off-by: Suraj Pai --- monai/metrics/generalized_dice.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 70041ad1c9..3ebfd6f386 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -50,7 +50,10 @@ class GeneralizedDiceScore(CumulativeIterationMetric): new_default=MetricReduction.MEAN, since="1.4.0", replaced="1.5.0", - msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'.", + msg_suffix=( + "Old versions computed `mean` when `mean_batch` was provided due to bug in reduction, " + "If you want to retain the old behavior (calculating the mean), please explicitly set the parameter to 'mean'." + ), ) def __init__( self, From 72ed145abfafe3939a435865a946165f55c61ee5 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Fri, 6 Sep 2024 10:34:41 -0400 Subject: [PATCH 09/11] Update monai/metrics/generalized_dice.py Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Suraj Pai --- monai/metrics/generalized_dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 3ebfd6f386..aa80b467a0 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -182,7 +182,7 @@ def compute_generalized_dice( b[infs] = 0 b[infs] = torch.max(b) - # Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True + # Compute the weighted numerator and denominator, summing along the class axis when sum_over_classes is True if sum_over_classes: numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True) denom = (denominator * w).sum(dim=1, keepdim=True) From a76264982d627082eddbf1d27a572e469b2010c3 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Fri, 6 Sep 2024 11:38:16 -0400 Subject: [PATCH 10/11] Fix docs + improve tests _________ DCO Remediation Commit for Suraj Pai I, Suraj Pai , hereby add my Signed-off-by to this commit: 2081cd263176053a78bc987e301a0ceb6a417b93 I, Suraj Pai , hereby add my Signed-off-by to this commit: 3805fea8a2a38871a1b042082036ecf12499b633 I, Suraj Pai , hereby add my Signed-off-by to this commit: fb6e66067dcd10e5b22aa95847b9db80ba402662 I, Suraj Pai , hereby add my Signed-off-by to this commit: 827406095c24d3f50bc2d581fa65eea001e5321c Signed-off-by: Suraj Pai --- monai/metrics/generalized_dice.py | 2 +- tests/test_compute_generalized_dice.py | 154 ++++++++++++++++--------- 2 files changed, 100 insertions(+), 56 deletions(-) diff --git a/monai/metrics/generalized_dice.py b/monai/metrics/generalized_dice.py index 3ebfd6f386..2e62023aca 100644 --- a/monai/metrics/generalized_dice.py +++ b/monai/metrics/generalized_dice.py @@ -35,7 +35,7 @@ class GeneralizedDiceScore(CumulativeIterationMetric): include_background: Whether to include the background class (assumed to be in channel 0) in the score computation. Defaults to True. reduction: Define mode of reduction to the metrics. Available reduction modes: - {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, + {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``, ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction. weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform ground truth volume into a weight factor. Defaults to ``"square"``. diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index 448a17eaa5..51444adfdc 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -22,13 +22,13 @@ _device = "cuda:0" if torch.cuda.is_available() else "cpu" # keep background -TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) +TEST_CASE_1 = [ # y (1, 1, 2, 2), y_pred (1, 1, 2, 2), expected out (1, 1) with compute_generalized_dice { "y_pred": torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]], device=_device), "y": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]], device=_device), "include_background": True, }, - [0.8], + [[0.8]], ] # remove background @@ -52,28 +52,7 @@ [0.583333, 0.333333], ] -# should return 0 for both cases -TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (2, 3) - { - "y_pred": torch.tensor( - [ - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], - [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]], [[1.0, 1.0], [1.0, 1.0]]], - ] - ), - "y": torch.tensor( - [ - [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], - [[[0.0, 1.0], [1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]]], - ] - ), - "include_background": True, - }, - [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], -] - -TEST_CASE_4 = [ - {"include_background": True, "reduction": "mean_batch"}, +TEST_CASE_3 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -87,12 +66,13 @@ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), + "include_background": True, + "reduction": "mean", }, - [0.678571, 0.2, 0.333333], + [0.5454], ] -TEST_CASE_5 = [ - {"include_background": True, "reduction": "sum"}, +TEST_CASE_4 = [ # y (2, 3, 2, 2), y_pred (2, 3, 2, 2), expected out (1) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -106,32 +86,33 @@ [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], ] ), + "include_background": True, + "reduction": "sum", }, [1.045455], ] -TEST_CASE_6 = [ +TEST_CASE_5 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [[1.0000, 1.0000], [1.0000, 1.0000]], ] -TEST_CASE_7 = [ +TEST_CASE_6 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.ones((2, 2, 3, 3))}, [[0.0000, 0.0000], [0.0000, 0.0000]], ] -TEST_CASE_8 = [ +TEST_CASE_7 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice {"y": torch.ones((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [[0.0000, 0.0000], [0.0000, 0.0000]], ] -TEST_CASE_9 = [ +TEST_CASE_8 = [ # y (2, 2, 3, 3) y_pred (2, 2, 3, 3) expected out (2, 2) with compute_generalized_dice {"y": torch.zeros((2, 2, 3, 3)), "y_pred": torch.zeros((2, 2, 3, 3))}, [[1.0000, 1.0000], [1.0000, 1.0000]], ] -TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) - {"include_background": True, "reduction": "mean_channel"}, +TEST_CASE_9 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2) with GeneralizedDiceScore { "y_pred": torch.tensor( [ @@ -145,50 +126,113 @@ [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], ] ), + "include_background": True, + "reduction": "mean_channel", }, [0.545455, 0.545455], ] -class TestComputeGeneralizedDiceScore(unittest.TestCase): +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice and (3) with GeneralizedDiceScore "mean_batch" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + }, + [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]], +] + +TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes) and (2) with GeneralizedDiceScore "mean_channel" + { + "y_pred": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 0.0]], [[0.0, 1.0], [0.0, 0.0]], [[0.0, 1.0], [1.0, 1.0]]], + [[[1.0, 0.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]], + ] + ), + "y": torch.tensor( + [ + [[[1.0, 1.0], [1.0, 1.0]], [[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [0.0, 0.0]]], + [[[0.0, 0.0], [0.0, 1.0]], [[1.0, 1.0], [0.0, 0.0]], [[0.0, 0.0], [1.0, 0.0]]], + ] + ), + "include_background": True, + "sum_over_classes": True, + }, + [[0.545455], [0.545455]], +] + +class TestComputeGeneralizedDiceScore(unittest.TestCase): @parameterized.expand([TEST_CASE_1]) def test_device(self, input_data, _expected_value): + """ + Test if the result tensor is on the same device as the input tensor. + """ result = compute_generalized_dice(**input_data) np.testing.assert_equal(result.device, input_data["y_pred"].device) - # Functional part tests - @parameterized.expand([TEST_CASE_6, TEST_CASE_7, TEST_CASE_8, TEST_CASE_9]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7, TEST_CASE_8]) def test_value(self, input_data, expected_value): + """ + Test if the computed generalized dice score matches the expected value. + """ result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Functional part tests - @parameterized.expand([TEST_CASE_3]) - def test_nans(self, input_data, expected_value): - result = compute_generalized_dice(**input_data) - self.assertTrue(np.allclose(np.isnan(result.cpu().numpy()), expected_value)) - - # Samplewise tests - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_9]) def test_value_class(self, input_data, expected_value): - # same test as for compute_meandice - vals = {} - vals["y_pred"] = input_data.pop("y_pred") - vals["y"] = input_data.pop("y") + """ + Test if the GeneralizedDiceScore class computes the correct values. + """ + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") generalized_dice_score = GeneralizedDiceScore(**input_data) - generalized_dice_score(**vals) + generalized_dice_score(y_pred=y_pred, y=y) result = generalized_dice_score.aggregate() np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) - # Aggregation tests - @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_10]) - def test_nans_class(self, params, input_data, expected_value): - generalized_dice_score = GeneralizedDiceScore(**params) - generalized_dice_score(**input_data) - result = generalized_dice_score.aggregate() + @parameterized.expand([TEST_CASE_10]) + def test_values_compare(self, input_data, expected_value): + """ + Compare the results of compute_generalized_dice function and GeneralizedDiceScore class. + """ + result = compute_generalized_dice(**input_data) np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_batch") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=0), atol=1e-4) + + @parameterized.expand([TEST_CASE_11]) + def test_values_compare_sum_over_classes(self, input_data, expected_value): + """ + Compare the results when summing over classes between compute_generalized_dice function and GeneralizedDiceScore class. + """ + result = compute_generalized_dice(**input_data) + np.testing.assert_allclose(result.cpu().numpy(), expected_value, atol=1e-4) + + y_pred = input_data.pop("y_pred") + y = input_data.pop("y") + input_data.pop("sum_over_classes") + generalized_dice_score = GeneralizedDiceScore(**input_data, reduction="mean_channel") + generalized_dice_score(y_pred=y_pred, y=y) + result_class_mean = generalized_dice_score.aggregate() + np.testing.assert_allclose(result_class_mean.cpu().numpy(), np.mean(expected_value, axis=1), atol=1e-4) + if __name__ == "__main__": unittest.main() From 27a0c58b46fab6b0f7c8401f5e137ce447b3d726 Mon Sep 17 00:00:00 2001 From: Suraj Pai Date: Fri, 6 Sep 2024 11:48:55 -0400 Subject: [PATCH 11/11] Fix flake8 error Signed-off-by: Suraj Pai --- tests/test_compute_generalized_dice.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_compute_generalized_dice.py b/tests/test_compute_generalized_dice.py index 51444adfdc..985a01e993 100644 --- a/tests/test_compute_generalized_dice.py +++ b/tests/test_compute_generalized_dice.py @@ -133,7 +133,8 @@ ] -TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice and (3) with GeneralizedDiceScore "mean_batch" +TEST_CASE_10 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 3) with compute_generalized_dice + # and (3) with GeneralizedDiceScore "mean_batch" { "y_pred": torch.tensor( [ @@ -152,7 +153,8 @@ [[0.857143, 0.0, 0.0], [0.5, 0.4, 0.666667]], ] -TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes) and (2) with GeneralizedDiceScore "mean_channel" +TEST_CASE_11 = [ # y (2, 3, 2, 2) y_pred (2, 3, 2, 2) expected out (2, 1) with compute_generalized_dice (summed over classes) + # and (2) with GeneralizedDiceScore "mean_channel" { "y_pred": torch.tensor( [