Skip to content

Commit

Permalink
4103 enhances surface Dice to use subvoxel borders (Project-MONAI#6681)
Browse files Browse the repository at this point in the history
Fixes Project-MONAI#4103

### Description
- considers spacing and subvoxel borders when computing surface Dice
- reimplemented
http://medicaldecathlon.com/files/Surface_distance_based_measures.ipynb
using pytorch API
- there's possibility of speed up once
Project-MONAI#1332 is addressed


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
  • Loading branch information
wyli authored Jul 5, 2023
1 parent 18daff1 commit 922b11e
Show file tree
Hide file tree
Showing 4 changed files with 541 additions and 124 deletions.
79 changes: 46 additions & 33 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ class SurfaceDiceMetric(CumulativeIterationMetric):
Computes the Normalized Surface Dice (NSD) for each batch sample and class of
predicted segmentations `y_pred` and corresponding reference segmentations `y` according to equation :eq:`nsd`.
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length/area of a segmentation boundary is
Be aware that by default (`use_subvoxels=False`), the computation of boundaries is different from DeepMind's
mplementation https://github.com/deepmind/surface-distance.
In this implementation, the length/area of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
This issue is discussed here: https://github.com/Project-MONAI/MONAI/issues/4103.
Expand Down Expand Up @@ -86,7 +87,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
y: Reference segmentation.
It must be a one-hot encoded, batch-first tensor [B,C,H,W] or [B,C,H,W,D].
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
kwargs: additional parameters: ``spacing`` should be passed to correctly compute the metric.
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
Expand All @@ -96,6 +97,8 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.
Returns:
Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
Expand All @@ -108,6 +111,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any)
include_background=self.include_background,
distance_metric=self.distance_metric,
spacing=kwargs.get("spacing"),
use_subvoxels=kwargs.get("use_subvoxels", False),
)

def aggregate(
Expand Down Expand Up @@ -141,13 +145,14 @@ def compute_surface_dice(
include_background: bool = False,
distance_metric: str = "euclidean",
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
use_subvoxels: bool = False,
) -> torch.Tensor:
r"""
This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
:math:`\hat{Y}`) and `y` (referred to as :math:`Y`). This metric determines which fraction of a segmentation
boundary is correctly predicted. A boundary element is considered correctly predicted if the closest distance to the
reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation in
pixels. The NSD is bounded between 0 and 1.
reference boundary is smaller than or equal to the specified threshold related to the acceptable amount of deviation
in pixels. The NSD is bounded between 0 and 1.
This implementation supports multi-class tasks with an individual threshold :math:`\tau_c` for each class :math:`c`.
The class-specific NSD for batch index :math:`b`, :math:`\operatorname {NSD}_{b,c}`, is computed using the function:
Expand All @@ -159,24 +164,23 @@ def compute_surface_dice(
:label: nsd
with :math:`\mathcal{D}_{Y_{b,c}}` and :math:`\mathcal{D}_{\hat{Y}_{b,c}}` being two sets of nearest-neighbor
distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference segmentation
boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
distances. :math:`\mathcal{D}_{Y_{b,c}}` is computed from the predicted segmentation boundary towards the reference
segmentation boundary and vice-versa for :math:`\mathcal{D}_{\hat{Y}_{b,c}}`. :math:`\mathcal{D}_{Y_{b,c}}^{'}` and
:math:`\mathcal{D}_{\hat{Y}_{b,c}}^{'}` refer to the subsets of distances that are smaller or equal to the
acceptable distance :math:`\tau_c`:
.. math::
\mathcal{D}_{Y_{b,c}}^{'} = \{ d \in \mathcal{D}_{Y_{b,c}} \, | \, d \leq \tau_c \}.
In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation, a nan value
will be returned for this class. In the case of a class being present in only one of predicted segmentation or
reference segmentation, the class NSD will be 0.
In the case of a class neither being present in the predicted segmentation, nor in the reference segmentation,
a nan value will be returned for this class. In the case of a class being present in only one of predicted
segmentation or reference segmentation, the class NSD will be 0.
This implementation is based on https://arxiv.org/abs/2111.05408 and supports 2D and 3D images.
Be aware that the computation of boundaries is different from DeepMind's implementation
https://github.com/deepmind/surface-distance. In this implementation, the length of a segmentation boundary is
interpreted as the number of its edge pixels. In DeepMind's implementation, the length of a segmentation boundary
depends on the local neighborhood (cf. https://arxiv.org/abs/1809.04430).
The computation of boundaries follows DeepMind's implementation
https://github.com/deepmind/surface-distance when `use_subvoxels=True`; Otherwise the length of a segmentation
boundary is interpreted as the number of its edge pixels.
Args:
y_pred: Predicted segmentation, typically segmentation model output.
Expand All @@ -198,6 +202,7 @@ def compute_surface_dice(
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
use_subvoxels: Whether to use subvoxel distances. Defaults to ``False``.
Raises:
ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
Expand Down Expand Up @@ -227,11 +232,6 @@ def compute_surface_dice(
f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
)

if not torch.all(y_pred.byte() == y_pred) or not torch.all(y.byte() == y):
raise ValueError("y_pred and y should be binarized tensors (e.g. torch.int64).")
if torch.any(y_pred > 1) or torch.any(y > 1):
raise ValueError("y_pred and y should be one-hot encoded.")

y = y.float()
y_pred = y_pred.float()

Expand All @@ -254,24 +254,37 @@ def compute_surface_dice(
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
if not use_subvoxels:
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=True)
distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
distances_gt_pred <= class_thresholds[c]
)
else:
_spacing = spacing_list[b] if spacing_list[b] is not None else [1] * img_dim
areas_pred: np.ndarray
areas_gt: np.ndarray
edges_pred, edges_gt, areas_pred, areas_gt = get_mask_edges( # type: ignore
y_pred[b, c], y[b, c], crop=True, spacing=_spacing # type: ignore
)
dist_pred_to_gt = get_surface_distance(edges_pred, edges_gt, distance_metric, spacing=spacing_list[b])
dist_gt_to_pred = get_surface_distance(edges_gt, edges_pred, distance_metric, spacing=spacing_list[b])
areas_gt, areas_pred = areas_gt[edges_gt], areas_pred[edges_pred]
boundary_complete = areas_gt.sum() + areas_pred.sum()
gt_true = areas_gt[dist_gt_to_pred <= class_thresholds[c]].sum() if len(areas_gt) > 0 else 0.0
pred_true = areas_pred[dist_pred_to_gt <= class_thresholds[c]].sum() if len(areas_pred) > 0 else 0.0
boundary_correct = gt_true + pred_true
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
distances_gt_pred <= class_thresholds[c]
)

if boundary_complete == 0:
# the class is neither present in the prediction, nor in the reference segmentation
nsd[b, c] = np.nan
Expand Down
Loading

0 comments on commit 922b11e

Please sign in to comment.