Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added spacing to surface distances calculations #6144

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
2d449b4
added spacing to surface distances calculations
gasperpodobnik Mar 14, 2023
069efb0
Merge branch 'dev' into 6137-add-spacing-to-surface-distance-metrics
gasperpodobnik Mar 14, 2023
26dd22c
code formatting
gasperpodobnik Mar 14, 2023
753253c
deleted unnecessary whitespaces
gasperpodobnik Mar 14, 2023
b13b1d4
Merge branch '6137-add-spacing-to-surface-distance-metrics' of https:…
gasperpodobnik Mar 14, 2023
c7e6711
Merge branch 'dev' into 6137-add-spacing-to-surface-distance-metrics
gasperpodobnik Mar 14, 2023
4b15604
formatted with black
gasperpodobnik Mar 14, 2023
8585dad
Merge branch '6137-add-spacing-to-surface-distance-metrics' of https:…
gasperpodobnik Mar 14, 2023
5d7bb44
minor
gasperpodobnik Mar 14, 2023
65436a9
DCO Remediation Commit for gasperp <gasper.podobnik@gmail.com>
gasperpodobnik Mar 14, 2023
6cf33bf
minor
gasperpodobnik Mar 14, 2023
448a211
DCO Remediation Commit for gasperp <gasper.podobnik@gmail.com>
gasperpodobnik Mar 14, 2023
bb53664
removed checker for length of spacing parameter
gasperpodobnik Mar 15, 2023
5c8a4a3
spacing parameter can now be passed to metric call via kwargs
gasperpodobnik Mar 27, 2023
f963ea0
added tests that include spacing parameter
gasperpodobnik Mar 27, 2023
99a73df
fixed isort error
gasperpodobnik Mar 27, 2023
66383aa
minor fix
gasperpodobnik Mar 27, 2023
f30e355
minor
gasperpodobnik Mar 27, 2023
20b9c02
fixed formatting
gasperpodobnik Mar 27, 2023
ffc38ff
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into 613…
gasperpodobnik Mar 27, 2023
aa2196e
minor
gasperpodobnik Mar 27, 2023
ed60404
fixed docs issue
gasperpodobnik Mar 27, 2023
378d023
fixed line too long issue
gasperpodobnik Mar 27, 2023
5f71c71
fixed mypy issues
gasperpodobnik Apr 12, 2023
23d7fea
Merge branch 'dev' of https://github.com/Project-MONAI/MONAI into 613…
gasperpodobnik Apr 12, 2023
280cd9e
minor isort fix
gasperpodobnik Apr 12, 2023
658259a
minor black fix
gasperpodobnik Apr 12, 2023
f992e31
minor doc improvement
gasperpodobnik Apr 12, 2023
de816d2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2023
803d129
Merge branch 'dev' into 6137-add-spacing-to-surface-distance-metrics
wyli Apr 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 46 additions & 6 deletions monai/metrics/hausdorff_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
Expand Down Expand Up @@ -70,21 +78,32 @@ def __init__(
self.reduction = reduction
self.get_not_nans = get_not_nans

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override]
"""
Args:
y_pred: input data to compute, typical segmentation model output.
It must be one-hot format and first dim is batch, example shape: [16, 3, 32, 32]. The values
should be binarized.
y: ground truth to compute the distance. It must be one-hot format and first dim is batch.
The values should be binarized.
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
the length of the sequence must be equal to the image dimensions.
This spacing will be used for all images in the batch.
If a sequence of sequences, the length of the outer sequence must be equal to the batch size.
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.

Raises:
ValueError: when `y_pred` has less than three dimensions.
"""
dims = y_pred.ndimension()
if dims < 3:
raise ValueError("y_pred should have at least three dimensions.")

# compute (BxC) for each channel for each batch
return compute_hausdorff_distance(
y_pred=y_pred,
Expand All @@ -93,6 +112,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
distance_metric=self.distance_metric,
percentile=self.percentile,
directed=self.directed,
spacing=kwargs.get("spacing"),
)

def aggregate(
Expand Down Expand Up @@ -123,6 +143,7 @@ def compute_hausdorff_distance(
distance_metric: str = "euclidean",
percentile: float | None = None,
directed: bool = False,
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
) -> torch.Tensor:
"""
Compute the Hausdorff distance.
Expand All @@ -141,6 +162,13 @@ def compute_hausdorff_distance(
percentile of the Hausdorff Distance rather than the maximum result will be achieved.
Defaults to ``None``.
directed: whether to calculate directed Hausdorff distance. Defaults to ``False``.
spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch.
If a sequence of sequences, the length of the outer sequence must be equal to the batch size.
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.
"""

if not include_background:
Expand All @@ -153,30 +181,42 @@ def compute_hausdorff_distance(

batch_size, n_class = y_pred.shape[:2]
hd = np.empty((batch_size, n_class))

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c])
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distance_1 = compute_percent_hausdorff_distance(edges_pred, edges_gt, distance_metric, percentile)
distance_1 = compute_percent_hausdorff_distance(
edges_pred, edges_gt, distance_metric, percentile, spacing_list[b]
)
if directed:
hd[b, c] = distance_1
else:
distance_2 = compute_percent_hausdorff_distance(edges_gt, edges_pred, distance_metric, percentile)
distance_2 = compute_percent_hausdorff_distance(
edges_gt, edges_pred, distance_metric, percentile, spacing_list[b]
)
hd[b, c] = max(distance_1, distance_2)
return convert_data_type(hd, output_type=torch.Tensor, device=y_pred.device, dtype=torch.float)[0]


def compute_percent_hausdorff_distance(
edges_pred: np.ndarray, edges_gt: np.ndarray, distance_metric: str = "euclidean", percentile: float | None = None
edges_pred: np.ndarray,
edges_gt: np.ndarray,
distance_metric: str = "euclidean",
percentile: float | None = None,
spacing: int | float | np.ndarray | Sequence[int | float] | None = None,
) -> float:
"""
This function is used to compute the directed Hausdorff distance.
"""

surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
surface_distance = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing)

# for both pred and gt do not have foreground
if surface_distance.shape == (0,):
Expand Down
4 changes: 3 additions & 1 deletion monai/metrics/loss_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@

from __future__ import annotations

from typing import Any

import torch
from torch.nn.modules.loss import _Loss

Expand Down Expand Up @@ -92,7 +94,7 @@ def aggregate(
f, not_nans = do_metric_reduction(data, reduction or self.reduction)
return (f, not_nans) if self.get_not_nans else f

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList:
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList:
"""
Input `y_pred` is compared with ground truth `y`.
Both `y_pred` and `y` are expected to be a batch-first Tensor (BC[HWD]).
Expand Down
23 changes: 14 additions & 9 deletions monai/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class IterationMetric(Metric):
"""

def __call__(
self, y_pred: TensorOrList, y: TensorOrList | None = None
self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any
) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]:
"""
Execute basic computation for model prediction `y_pred` and ground truth `y` (optional).
Expand All @@ -60,6 +60,7 @@ def __call__(
or a `batch-first` Tensor.
y: the ground truth to compute, must be a list of `channel-first` Tensor
or a `batch-first` Tensor.
kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.).

Returns:
The computed metric values at the iteration level.
Expand All @@ -69,15 +70,15 @@ def __call__(
"""
# handling a list of channel-first data
if isinstance(y_pred, (list, tuple)) or isinstance(y, (list, tuple)):
return self._compute_list(y_pred, y)
return self._compute_list(y_pred, y, **kwargs)
# handling a single batch-first data
if isinstance(y_pred, torch.Tensor):
y_ = y.detach() if isinstance(y, torch.Tensor) else None
return self._compute_tensor(y_pred.detach(), y_)
return self._compute_tensor(y_pred.detach(), y_, **kwargs)
raise ValueError("y_pred or y must be a list/tuple of `channel-first` Tensors or a `batch-first` Tensor.")

def _compute_list(
self, y_pred: TensorOrList, y: TensorOrList | None = None
self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any
) -> torch.Tensor | list[torch.Tensor | Sequence[torch.Tensor]]:
"""
Execute the metric computation for `y_pred` and `y` in a list of "channel-first" tensors.
Expand All @@ -93,9 +94,12 @@ def _compute_list(
Note: subclass may enhance the operation to have multi-thread support.
"""
if y is not None:
ret = [self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0)) for p, y_ in zip(y_pred, y)]
ret = [
self._compute_tensor(p.detach().unsqueeze(0), y_.detach().unsqueeze(0), **kwargs)
for p, y_ in zip(y_pred, y)
]
else:
ret = [self._compute_tensor(p_.detach().unsqueeze(0), None) for p_ in y_pred]
ret = [self._compute_tensor(p_.detach().unsqueeze(0), None, **kwargs) for p_ in y_pred]

# concat the list of results (e.g. a batch of evaluation scores)
if isinstance(ret[0], torch.Tensor):
Expand All @@ -106,7 +110,7 @@ def _compute_list(
return ret

@abstractmethod
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None) -> TensorOrList:
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor | None = None, **kwargs: Any) -> TensorOrList:
"""
Computation logic for `y_pred` and `y` of an iteration, the data should be "batch-first" Tensors.
A subclass should implement its own computation logic.
Expand Down Expand Up @@ -318,7 +322,7 @@ class CumulativeIterationMetric(Cumulative, IterationMetric):
"""

def __call__(
self, y_pred: TensorOrList, y: TensorOrList | None = None
self, y_pred: TensorOrList, y: TensorOrList | None = None, **kwargs: Any
) -> torch.Tensor | Sequence[torch.Tensor | Sequence[torch.Tensor]]:
"""
Execute basic computation for model prediction and ground truth.
Expand All @@ -331,12 +335,13 @@ def __call__(
or a `batch-first` Tensor.
y: the ground truth to compute, must be a list of `channel-first` Tensor
or a `batch-first` Tensor.
kwargs: additional parameters for specific metric computation logic (e.g. ``spacing`` for SurfaceDistanceMetric, etc.).

Returns:
The computed metric values at the iteration level. The output shape should be
a `batch-first` tensor (BC[HWD]) or a list of `batch-first` tensors.
"""
ret = super().__call__(y_pred=y_pred, y=y)
ret = super().__call__(y_pred=y_pred, y=y, **kwargs)
if isinstance(ret, (tuple, list)):
self.extend(*ret)
else:
Expand Down
42 changes: 38 additions & 4 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,19 @@
from __future__ import annotations

import warnings
from collections.abc import Sequence
from typing import Any

import numpy as np
import torch

from monai.metrics.utils import do_metric_reduction, get_mask_edges, get_surface_distance, ignore_background
from monai.metrics.utils import (
do_metric_reduction,
get_mask_edges,
get_surface_distance,
ignore_background,
prepare_spacing,
)
from monai.utils import MetricReduction, convert_data_type

from .metric import CumulativeIterationMetric
Expand Down Expand Up @@ -67,13 +75,23 @@ def __init__(
self.reduction = reduction
self.get_not_nans = get_not_nans

def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # type: ignore[override]
def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor, **kwargs: Any) -> torch.Tensor: # type: ignore[override]
r"""
Args:
y_pred: Predicted segmentation, typically segmentation model output.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
y: Reference segmentation.
It must be a one-hot encoded, batch-first tensor [B,C,H,W].
kwargs: additional parameters, e.g. ``spacing`` should be passed to correctly compute the metric.
``spacing``: spacing of pixel (or voxel). This parameter is relevant only
if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
the length of the sequence must be equal to the image dimensions.
This spacing will be used for all images in the batch.
If a sequence of sequences, the length of the outer sequence must be equal to the batch size.
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.

Returns:
Pytorch Tensor of shape [B,C], containing the NSD values :math:`\operatorname {NSD}_{b,c}` for each batch
Expand All @@ -85,6 +103,7 @@ def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
class_thresholds=self.class_thresholds,
include_background=self.include_background,
distance_metric=self.distance_metric,
spacing=kwargs.get("spacing"),
)

def aggregate(
Expand Down Expand Up @@ -117,6 +136,7 @@ def compute_surface_dice(
class_thresholds: list[float],
include_background: bool = False,
distance_metric: str = "euclidean",
spacing: int | float | np.ndarray | Sequence[int | float | np.ndarray | Sequence[int | float]] | None = None,
) -> torch.Tensor:
r"""
This function computes the (Normalized) Surface Dice (NSD) between the two tensors `y_pred` (referred to as
Expand Down Expand Up @@ -167,6 +187,13 @@ def compute_surface_dice(
distance_metric: The metric used to compute surface distances.
One of [``"euclidean"``, ``"chessboard"``, ``"taxicab"``].
Defaults to ``"euclidean"``.
spacing: spacing of pixel (or voxel). This parameter is relevant only if ``distance_metric`` is set to ``"euclidean"``.
If a single number, isotropic spacing with that value is used for all images in the batch. If a sequence of numbers,
the length of the sequence must be equal to the image dimensions. This spacing will be used for all images in the batch.
If a sequence of sequences, the length of the outer sequence must be equal to the batch size.
If inner sequence has length 1, isotropic spacing with that value is used for all images in the batch,
else the inner sequence length must be equal to the image dimensions. If ``None``, spacing of unity is used
for all images in batch. Defaults to ``None``.

Raises:
ValueError: If `y_pred` and/or `y` are not PyTorch tensors.
Expand Down Expand Up @@ -219,15 +246,22 @@ def compute_surface_dice(

nsd = np.empty((batch_size, n_class))

img_dim = y_pred.ndim - 2
spacing_list = prepare_spacing(spacing=spacing, batch_size=batch_size, img_dim=img_dim)

for b, c in np.ndindex(batch_size, n_class):
(edges_pred, edges_gt) = get_mask_edges(y_pred[b, c], y[b, c], crop=False)
if not np.any(edges_gt):
warnings.warn(f"the ground truth of class {c} is all 0, this may result in nan/inf distance.")
if not np.any(edges_pred):
warnings.warn(f"the prediction of class {c} is all 0, this may result in nan/inf distance.")

distances_pred_gt = get_surface_distance(edges_pred, edges_gt, distance_metric=distance_metric)
distances_gt_pred = get_surface_distance(edges_gt, edges_pred, distance_metric=distance_metric)
distances_pred_gt = get_surface_distance(
edges_pred, edges_gt, distance_metric=distance_metric, spacing=spacing_list[b]
)
distances_gt_pred = get_surface_distance(
edges_gt, edges_pred, distance_metric=distance_metric, spacing=spacing_list[b]
)

boundary_complete = len(distances_pred_gt) + len(distances_gt_pred)
boundary_correct = np.sum(distances_pred_gt <= class_thresholds[c]) + np.sum(
Expand Down
Loading