From a21bb20a2599b04f7ffacd01921e1d627a1ef5c4 Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 29 Nov 2023 15:19:04 -0500 Subject: [PATCH 1/4] Add diffusion loss Signed-off-by: kaibo --- monai/losses/deform.py | 71 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index dd03a8eb3d..41defb8fda 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -116,3 +116,74 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') return energy + + +class DiffusionLoss(_Loss): + """ + Calculate the diffusion based on first-order differentiation of pred using central finite difference. + For the original paper, please refer to + VoxelMorph: A Learning Framework for Deformable Medical Image Registration, + Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca + IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231. + + Adapted from: + VoxelMorph (https://github.com/voxelmorph/voxelmorph) + """ + + def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None: + """ + Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. + reduction: {``"none"``, ``"mean"``, ``"sum"``} + Specifies the reduction to apply to the output. Defaults to ``"mean"``. + + - ``"none"``: no reduction will be applied. + - ``"mean"``: the sum of the output will be divided by the number of elements in the output. + - ``"sum"``: the output will be summed. + """ + super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize + + def forward(self, pred: torch.Tensor) -> torch.Tensor: + """ + Args: + pred: the shape should be BCH(WD) + + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + + """ + if pred.ndim not in [3, 4, 5]: + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + for i in range(pred.ndim - 2): + if pred.shape[-i - 1] <= 4: + raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim - 2}" + ) + + # first order gradient + first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + + diffusion = torch.tensor(0) + for dim_1, g in enumerate(first_order_gradient): + dim_1 += 2 + if self.normalize: + g *= pred.shape[dim_1] / spatial_dims + diffusion = diffusion + g ** 2 + + if self.reduction == LossReduction.MEAN.value: + diffusion = torch.mean(diffusion) # the batch and channel average + elif self.reduction == LossReduction.SUM.value: + diffusion = torch.sum(diffusion) # sum over the batch and channel dims + elif self.reduction != LossReduction.NONE.value: + raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].') + + return diffusion From d80790ac70c16fc26166cb8c1200095eb265b02a Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 29 Nov 2023 21:42:40 -0500 Subject: [PATCH 2/4] Add unittest with justifications, add docs, fix style Signed-off-by: kaibo --- docs/source/losses.rst | 5 ++ monai/losses/deform.py | 16 ++--- tests/test_diffusion_loss.py | 116 +++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 8 deletions(-) create mode 100644 tests/test_diffusion_loss.py diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 5d488afbb3..f98aa6ceb2 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -86,6 +86,11 @@ Registration Losses .. autoclass:: BendingEnergyLoss :members: +`DiffusionLoss` +~~~~~~~~~~~~~~~ +.. autoclass:: DiffusionLoss + :members: + `LocalNormalizedCrossCorrelationLoss` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: LocalNormalizedCrossCorrelationLoss diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 41defb8fda..02d396dcd2 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -148,18 +148,18 @@ def __init__(self, normalize: bool = False, reduction: LossReduction | str = Los def forward(self, pred: torch.Tensor) -> torch.Tensor: """ - Args: - pred: the shape should be BCH(WD) + Args: + pred: the shape should be BCH(WD) - Raises: - ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + Raises: + ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. - """ + """ if pred.ndim not in [3, 4, 5]: raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): - if pred.shape[-i - 1] <= 4: - raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[-i - 1] <= 2: + raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim - 2}" @@ -177,7 +177,7 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: dim_1 += 2 if self.normalize: g *= pred.shape[dim_1] / spatial_dims - diffusion = diffusion + g ** 2 + diffusion = diffusion + g**2 if self.reduction == LossReduction.MEAN.value: diffusion = torch.mean(diffusion) # the batch and channel average diff --git a/tests/test_diffusion_loss.py b/tests/test_diffusion_loss.py new file mode 100644 index 0000000000..05dfab95fb --- /dev/null +++ b/tests/test_diffusion_loss.py @@ -0,0 +1,116 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch +from parameterized import parameterized + +from monai.losses.deform import DiffusionLoss + +device = "cuda" if torch.cuda.is_available() else "cpu" + +TEST_CASES = [ + # all first partials are zero, so the diffusion loss is also zero + [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], + # all first partials are one, so the diffusion loss is also one + [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0], + # before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67 + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + # same as the previous case + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # we have shown in the demo notebook that + # diffusion loss is scale-invariant when the all axes have the same resolution + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 56.0 / 3.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0], + # for the following case, consider the following 2D matrix: + # tensor([[[[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]], + # [[0, 1, 2], + # [1, 2, 3], + # [2, 3, 4], + # [3, 4, 5], + # [4, 5, 6]]]]) + # the first partials wrt x are all ones, and so are the first partials wrt y + # the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2 + [{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0], + # consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook, + # the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y + # the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689 + [ + {"normalize": True}, + {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, + (1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0, + ], +] + + +class TestDiffusionLoss(unittest.TestCase): + @parameterized.expand(TEST_CASES) + def test_shape(self, input_param, input_data, expected_val): + result = DiffusionLoss(**input_param).forward(**input_data) + np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5) + + def test_ill_shape(self): + loss = DiffusionLoss() + # not in 3-d, 4-d, 5-d + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 3), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 2, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 2, 5))) + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): + loss.forward(torch.ones((1, 3, 5, 5, 2))) + + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + + def test_ill_opts(self): + pred = torch.rand(1, 3, 5, 5, 5).to(device=device) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction="unknown")(pred) + with self.assertRaisesRegex(ValueError, ""): + DiffusionLoss(reduction=None)(pred) + + +if __name__ == "__main__": + unittest.main() From b48c5fc69f08e4f9b3703450c5d769405275145c Mon Sep 17 00:00:00 2001 From: kaibo Date: Wed, 29 Nov 2023 21:49:03 -0500 Subject: [PATCH 3/4] Add diffusion loss to init Signed-off-by: kaibo --- monai/losses/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index d734a9d44d..92898c81ca 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -14,7 +14,7 @@ from .adversarial_loss import PatchAdversarialLoss from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss from .contrastive import ContrastiveLoss -from .deform import BendingEnergyLoss +from .deform import BendingEnergyLoss, DiffusionLoss from .dice import ( Dice, DiceCELoss, From c97651b31df405c26db1bbd112d789a80fcc81b3 Mon Sep 17 00:00:00 2001 From: kaibo Date: Mon, 4 Dec 2023 09:34:58 -0500 Subject: [PATCH 4/4] Add docstrings Signed-off-by: kaibo --- monai/losses/deform.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 02d396dcd2..129abeedd2 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -149,10 +149,17 @@ def __init__(self, normalize: bool = False, reduction: LossReduction | str = Los def forward(self, pred: torch.Tensor) -> torch.Tensor: """ Args: - pred: the shape should be BCH(WD) + pred: + Predicted dense displacement field (DDF) with shape BCH[WD], + where C is the number of spatial dimensions. + Note that diffusion loss can only be calculated + when the sizes of the DDF along all spatial dimensions are greater than 2. Raises: ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. + ValueError: When ``pred`` is not 3-d, 4-d or 5-d. + ValueError: When any spatial dimension of ``pred`` has size less than or equal to 2. + ValueError: When the number of channels of ``pred`` does not match the number of spatial dimensions. """ if pred.ndim not in [3, 4, 5]: @@ -162,7 +169,8 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}") if pred.shape[1] != pred.ndim - 2: raise ValueError( - f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim - 2}" + f"Number of vector components, i.e. number of channels of the input DDF, {pred.shape[1]}, " + f"does not match number of spatial dimensions, {pred.ndim - 2}" ) # first order gradient @@ -176,6 +184,9 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 if self.normalize: + # We divide the partial derivative for each vector component at each voxel by the spatial size + # corresponding to that component relative to the spatial size of the vector component with respect + # to which the partial derivative is taken. g *= pred.shape[dim_1] / spatial_dims diffusion = diffusion + g**2