Skip to content

Commit

Permalink
Add updates from comments
Browse files Browse the repository at this point in the history
  • Loading branch information
surajpaib committed Aug 27, 2024
1 parent 61bbb60 commit 2081cd2
Showing 1 changed file with 26 additions and 9 deletions.
35 changes: 26 additions & 9 deletions monai/metrics/generalized_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch

from monai.metrics.utils import do_metric_reduction, ignore_background
from monai.utils import MetricReduction, Weight, look_up_option
from monai.utils import MetricReduction, Weight, deprecated_arg, deprecated_arg_default, look_up_option

from .metric import CumulativeIterationMetric

Expand Down Expand Up @@ -44,16 +44,27 @@ class GeneralizedDiceScore(CumulativeIterationMetric):
ValueError: When the `reduction` is not one of MetricReduction enum.
"""

@deprecated_arg_default(
"reduction",
MetricReduction.MEAN_BATCH,
MetricReduction.MEAN,
since="1.4.0",
replaced="1.4.0",
msg_suffix="Old versions computed `mean` when `mean_batch` was provided due to bug in reduction",
)
def __init__(
self, include_background: bool = True, reduction: str = "mean", weight_type: Weight | str = Weight.SQUARE
self,
include_background: bool = True,
reduction: str = MetricReduction.MEAN,
weight_type: Weight | str = Weight.SQUARE,
) -> None:
super().__init__()
self.include_background = include_background
self.reduction = look_up_option(reduction, MetricReduction)
self.weight_type = look_up_option(weight_type, Weight)
self.sum_over_labels = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN
self.sum_over_classes = self.reduction == MetricReduction.SUM or self.reduction == MetricReduction.MEAN

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Expand All @@ -63,7 +74,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
y: Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.
Returns:
torch.Tensor: Per batch and per class Generalized Dice Score.
torch.Tensor: Generalized Dice Score averaged across batch and class
Raises:
ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
Expand All @@ -73,10 +84,16 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
y=y,
include_background=self.include_background,
weight_type=self.weight_type,
sum_over_labels=self.sum_over_labels,
sum_over_classes=self.sum_over_classes,
)

def aggregate(self) -> torch.Tensor:
@deprecated_arg(
"reduction",
since="1.4.0",
removed="1.7.0",
msg_suffix="Reduction will be ignored. Set reduction during init. as gen.dice needs it during compute",
)
def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor:
"""
Execute reduction logic for the output of `compute_generalized_dice`.
Expand All @@ -101,7 +118,7 @@ def compute_generalized_dice(
y: torch.Tensor,
include_background: bool = True,
weight_type: Weight | str = Weight.SQUARE,
sum_over_labels: bool = False,
sum_over_classes: bool = False,
) -> torch.Tensor:
"""
Computes the Generalized Dice Score and returns a tensor with its per image values.
Expand Down Expand Up @@ -158,7 +175,7 @@ def compute_generalized_dice(
b[infs] = torch.max(b)

# Compute the weighted numerator and denominator, summing along the class axis when sum_over_labels is True
if sum_over_labels:
if sum_over_classes:
numer = 2.0 * (intersection * w).sum(dim=1, keepdim=True)
denom = (denominator * w).sum(dim=1, keepdim=True)
y_pred_o = y_pred_o.sum(dim=-1, keepdim=True)
Expand Down

0 comments on commit 2081cd2

Please sign in to comment.