Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

7263 add diffusion loss #7272

Merged
merged 9 commits into from
Dec 5, 2023
5 changes: 5 additions & 0 deletions docs/source/losses.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@ Registration Losses
.. autoclass:: BendingEnergyLoss
:members:

`DiffusionLoss`
~~~~~~~~~~~~~~~
.. autoclass:: DiffusionLoss
:members:

`LocalNormalizedCrossCorrelationLoss`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: LocalNormalizedCrossCorrelationLoss
Expand Down
2 changes: 1 addition & 1 deletion monai/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from .adversarial_loss import PatchAdversarialLoss
from .cldice import SoftclDiceLoss, SoftDiceclDiceLoss
from .contrastive import ContrastiveLoss
from .deform import BendingEnergyLoss
from .deform import BendingEnergyLoss, DiffusionLoss
from .dice import (
Dice,
DiceCELoss,
Expand Down
71 changes: 71 additions & 0 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,74 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return energy


class DiffusionLoss(_Loss):
"""
Calculate the diffusion based on first-order differentiation of pred using central finite difference.
For the original paper, please refer to
VoxelMorph: A Learning Framework for Deformable Medical Image Registration,
Guha Balakrishnan, Amy Zhao, Mert R. Sabuncu, John Guttag, Adrian V. Dalca
IEEE TMI: Transactions on Medical Imaging. 2019. eprint arXiv:1809.05231.

Adapted from:
VoxelMorph (https://github.com/voxelmorph/voxelmorph)
"""

def __init__(self, normalize: bool = False, reduction: LossReduction | str = LossReduction.MEAN) -> None:
"""
Args:
normalize:
Whether to divide out spatial sizes in order to make the computation roughly
invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

- ``"none"``: no reduction will be applied.
- ``"mean"``: the sum of the output will be divided by the number of elements in the output.
- ``"sum"``: the output will be summed.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.normalize = normalize

def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""
Args:
pred: the shape should be BCH(WD)
kvttt marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
kvttt marked this conversation as resolved.
Show resolved Hide resolved

"""
if pred.ndim not in [3, 4, 5]:
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
for i in range(pred.ndim - 2):
if pred.shape[-i - 1] <= 2:
raise ValueError(f"All spatial dimensions must be > 2, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim - 2}"
)

# first order gradient
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

# spatial dimensions in a shape suited for broadcasting below
if self.normalize:
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))

diffusion = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
if self.normalize:
g *= pred.shape[dim_1] / spatial_dims
diffusion = diffusion + g**2
kvttt marked this conversation as resolved.
Show resolved Hide resolved

if self.reduction == LossReduction.MEAN.value:
diffusion = torch.mean(diffusion) # the batch and channel average
elif self.reduction == LossReduction.SUM.value:
diffusion = torch.sum(diffusion) # sum over the batch and channel dims
elif self.reduction != LossReduction.NONE.value:
raise ValueError(f'Unsupported reduction: {self.reduction}, available options are ["mean", "sum", "none"].')

return diffusion
116 changes: 116 additions & 0 deletions tests/test_diffusion_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import unittest

import numpy as np
import torch
from parameterized import parameterized

from monai.losses.deform import DiffusionLoss

device = "cuda" if torch.cuda.is_available() else "cpu"

TEST_CASES = [
# all first partials are zero, so the diffusion loss is also zero
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
# all first partials are one, so the diffusion loss is also one
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 1.0],
# before expansion, the first partials are 2, 4, 6, so the diffusion loss is (2^2 + 4^2 + 6^2) / 3 = 18.67
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
# same as the previous case
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# we have shown in the demo notebook that
# diffusion loss is scale-invariant when the all axes have the same resolution
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
56.0 / 3.0,
],
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
56.0 / 3.0,
],
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 56.0 / 3.0],
# for the following case, consider the following 2D matrix:
# tensor([[[[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]],
# [[0, 1, 2],
# [1, 2, 3],
# [2, 3, 4],
# [3, 4, 5],
# [4, 5, 6]]]])
# the first partials wrt x are all ones, and so are the first partials wrt y
# the diffusion loss, when normalization is not applied, is 1^2 + 1^2 = 2
[{"normalize": False}, {"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)}, 2.0],
# consider the same matrix, this time with normalization applied, using the same notation as in the demo notebook,
# the coefficients to be divided out are (1, 5/3) for partials wrt x and (3/5, 1) for partials wrt y
# the diffusion loss is then (1/1)^2 + (1/(5/3))^2 + (1/(3/5))^2 + (1/1)^2 = (1 + 9/25 + 25/9 + 1) / 2 = 2.5689
[
{"normalize": True},
{"pred": torch.stack([torch.arange(i, i + 3) for i in range(5)]).expand(1, 2, 5, 3)},
(1.0 + 9.0 / 25.0 + 25.0 / 9.0 + 1.0) / 2.0,
],
]


class TestDiffusionLoss(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_shape(self, input_param, input_data, expected_val):
result = DiffusionLoss(**input_param).forward(**input_data)
np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)

def test_ill_shape(self):
loss = DiffusionLoss()
# not in 3-d, 4-d, 5-d
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 3), device=device))
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 2, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 2, 5)))
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 5, 2)))

# number of vector components unequal to number of spatial dims
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))

def test_ill_opts(self):
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction="unknown")(pred)
with self.assertRaisesRegex(ValueError, ""):
DiffusionLoss(reduction=None)(pred)


if __name__ == "__main__":
unittest.main()
Loading