From 8bb0fae3beebeda5f049cf87d936e2f9aef8caa4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jul 2023 22:57:41 +0100 Subject: [PATCH 1/2] Revert "fix `GeneralizedDiceLoss` (#5468)" This reverts commit e03ecd4f474bec099af56e415dcbfb178164c080. Signed-off-by: Wenqi Li --- monai/losses/dice.py | 5 +++-- tests/test_generalized_dice_loss.py | 12 ++++++------ 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 088ad50efd..0f34b3ec5f 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -360,8 +360,9 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1) w = w + infs * max_values - numer = 2.0 * (intersection * w) + self.smooth_nr - denom = (denominator * w) + self.smooth_dr + final_reduce_dim = 0 if self.batch else 1 + numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr + denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index b8256a41a9..d8ba496d03 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -48,7 +48,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.435035, + 0.469964, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -56,7 +56,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.3837, + 0.414507, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -71,7 +71,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 1.5348, + 0.829015, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -86,7 +86,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[[0.210949], [0.295351]], [[0.599976], [0.428522]]], + [[[0.273476]], [[0.555539]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -114,7 +114,7 @@ "input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1, 1, 0, 0]]]), }, - 0.26669, + 0.250023, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -136,7 +136,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - -8.55485, + -0.097833, ], ] From b48c53f38d275ce0a146c5369ec0da69cdfd99c4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 25 Jul 2023 23:01:01 +0100 Subject: [PATCH 2/2] update docstring Signed-off-by: Wenqi Li --- monai/losses/dice.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 0f34b3ec5f..4f53e3a7d8 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -268,6 +268,7 @@ def __init__( smooth_dr: a small constant added to the denominator to avoid nan. batch: whether to sum the intersection and union areas over the batch dimension before the dividing. Defaults to False, intersection over union is computed from each item in the batch. + If True, the class-weighted intersection and union areas are first summed across the batches. Raises: TypeError: When ``other_act`` is not an ``Optional[Callable]``.