diff --git a/docs/source/losses.rst b/docs/source/losses.rst
index 61dd959807..ba794af3eb 100644
--- a/docs/source/losses.rst
+++ b/docs/source/losses.rst
@@ -139,6 +139,11 @@ Reconstruction Losses
 .. autoclass:: JukeboxLoss
     :members:
 
+`SURELoss`
+~~~~~~~~~~
+.. autoclass:: SURELoss
+    :members:
+
 Loss Wrappers
 -------------
 
diff --git a/docs/source/networks.rst b/docs/source/networks.rst
index 8eada7933f..b59c8af5fc 100644
--- a/docs/source/networks.rst
+++ b/docs/source/networks.rst
@@ -408,6 +408,11 @@ Layers
 .. autoclass:: LLTM
     :members:
 
+`ConjugateGradient`
+~~~~~~~~~~~~~~~~~~~
+.. autoclass:: ConjugateGradient
+    :members:
+
 `Utilities`
 ~~~~~~~~~~~
 .. automodule:: monai.networks.layers.convutils
diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index 4ebedb2084..e937b53fa4 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -41,5 +41,6 @@ from .spatial_mask import MaskedLoss
 from .spectral_loss import JukeboxLoss
 from .ssim_loss import SSIMLoss
+from .sure_loss import SURELoss
 from .tversky import TverskyLoss
 from .unified_focal_loss import AsymmetricUnifiedFocalLoss
diff --git a/monai/losses/sure_loss.py b/monai/losses/sure_loss.py
new file mode 100644
index 0000000000..ebf25613a6
--- /dev/null
+++ b/monai/losses/sure_loss.py
@@ -0,0 +1,200 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Callable, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.loss import _Loss
+
+
+def complex_diff_abs_loss(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+    """
+    First compute the difference in the complex domain, then take the
+    absolute value of the difference, and finally compute the MSE of
+    that magnitude.
+
+    Args:
+        x, y: (B, 2, H, W) real-valued tensors representing complex numbers,
+            or (B, 1, H, W) complex-valued tensors.
+
+    Returns:
+        l2_loss: scalar tensor.
+    """
+    if not x.is_complex():
+        x = torch.view_as_complex(x.permute(0, 2, 3, 1).contiguous())
+    if not y.is_complex():
+        y = torch.view_as_complex(y.permute(0, 2, 3, 1).contiguous())
+
+    diff = torch.abs(x - y)
+    return nn.functional.mse_loss(diff, torch.zeros_like(diff), reduction="mean")
+
+
+def sure_loss_function(
+    operator: Callable,
+    x: torch.Tensor,
+    y_pseudo_gt: torch.Tensor,
+    y_ref: Optional[torch.Tensor] = None,
+    eps: Optional[float] = -1.0,
+    perturb_noise: Optional[torch.Tensor] = None,
+    complex_input: Optional[bool] = False,
+) -> torch.Tensor:
+    """
+    Args:
+        operator (function): The operator function that takes in an input
+            tensor x and returns an output tensor y. We will use this to compute
+            the divergence. More specifically, we will perturb the input x by a
+            small amount and compute the divergence between the perturbed output
+            and the reference output.
+
+        x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
+            operator. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+            For real input, the shape is (B, 1, H, W) real.
+
+        y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
+            (B, C, H, W) used to compute the L2 loss. For complex input, the shape is
+            (B, 2, H, W) aka C=2 real. For real input, the shape is (B, 1, H, W)
+            real.
+
+        y_ref (torch.Tensor, optional): The reference output tensor of shape
+            (B, C, H, W) used to compute the divergence. Defaults to None. For
+            complex input, the shape is (B, 2, H, W) aka C=2 real. For real input,
+            the shape is (B, 1, H, W) real.
+
+        eps (float, optional): The perturbation scalar. Set to -1 (or None) to have
+            it estimated automatically from the magnitude of y_pseudo_gt.
+
+        perturb_noise (torch.Tensor, optional): The noise vector of shape (B, C, H, W).
+            Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+            For real input, the shape is (B, 1, H, W) real.
+
+        complex_input (bool, optional): Whether the input is complex or not.
+            Defaults to False.
+
+    Returns:
+        sure_loss (torch.Tensor): The SURE loss scalar.
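+
+    Example:
+        A minimal sketch of the functional interface; the identity operator
+        stands in here for a trained network::
+
+            import torch
+
+            operator = lambda t: t
+            x = torch.randn(2, 1, 64, 64)
+            y_pseudo_gt = torch.randn(2, 1, 64, 64)
+            loss = sure_loss_function(operator, x, y_pseudo_gt, eps=0.1)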
+    """
+    # perturb input
+    if perturb_noise is None:
+        perturb_noise = torch.randn_like(x)
+    # estimate the perturbation scalar from y_pseudo_gt if not provided
+    if eps is None or eps == -1.0:
+        eps = float(torch.abs(y_pseudo_gt.max())) / 1000
+    # get y_ref if not provided
+    if y_ref is None:
+        y_ref = operator(x)
+
+    # get perturbed output
+    x_perturbed = x + eps * perturb_noise
+    y_perturbed = operator(x_perturbed)
+    # Monte Carlo estimate of the divergence of the operator
+    divergence = torch.sum(1.0 / eps * torch.matmul(perturb_noise.permute(0, 1, 3, 2), y_perturbed - y_ref))  # type: ignore
+    # l2 loss between y_ref, y_pseudo_gt
+    if complex_input:
+        l2_loss = complex_diff_abs_loss(y_ref, y_pseudo_gt)
+    else:
+        # real input
+        l2_loss = nn.functional.mse_loss(y_ref, y_pseudo_gt, reduction="mean")
+
+    # sure loss
+    sure_loss = l2_loss * divergence / (x.shape[0] * x.shape[2] * x.shape[3])
+    return sure_loss
+
+
+class SURELoss(_Loss):
+    """
+    Calculate the Stein's Unbiased Risk Estimator (SURE) loss for a given operator.
+
+    This is a differentiable loss function that can be used to train/guide an
+    operator (e.g. neural network), where the pseudo ground truth is available
+    but the reference ground truth is not. For example, in MRI
+    reconstruction, the pseudo ground truth is the zero-filled reconstruction
+    and the reference ground truth is the fully sampled reconstruction. Often,
+    the reference ground truth is not available due to the lack of fully sampled
+    data.
+
+    The original SURE loss is proposed in [1]. The SURE loss used for guiding
+    the diffusion model based MRI reconstruction is proposed in [2].
+
+    References
+
+    [1] Stein, C.M.: Estimation of the mean of a multivariate normal distribution. Annals of Statistics
+
+    [2] B. Ozturkler et al. SMRD: SURE-based Robust MRI Reconstruction with Diffusion Models.
+    (https://arxiv.org/pdf/2310.01799.pdf)
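+
+    Example:
+        A minimal usage sketch; the identity operator stands in for a
+        reconstruction network and random tensors stand in for MRI data::
+
+            import torch
+            from monai.losses import SURELoss
+
+            loss_fn = SURELoss(eps=0.1)
+            operator = lambda t: t
+            x = torch.randn(2, 2, 64, 64)  # two-channel real view of complex data
+            y_pseudo_gt = torch.randn(2, 2, 64, 64)
+            loss = loss_fn(operator, x, y_pseudo_gt, complex_input=True)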
+    """
+
+    def __init__(self, perturb_noise: Optional[torch.Tensor] = None, eps: Optional[float] = None) -> None:
+        """
+        Args:
+            perturb_noise (torch.Tensor, optional): The noise vector of shape
+                (B, C, H, W). Defaults to None. For complex input, the shape is (B, 2, H, W) aka C=2 real.
+                For real input, the shape is (B, 1, H, W) real.
+
+            eps (float, optional): The perturbation scalar. Defaults to None,
+                in which case it is estimated automatically from y_pseudo_gt.
+        """
+        super().__init__()
+        self.perturb_noise = perturb_noise
+        self.eps = eps
+
+    def forward(
+        self,
+        operator: Callable,
+        x: torch.Tensor,
+        y_pseudo_gt: torch.Tensor,
+        y_ref: Optional[torch.Tensor] = None,
+        complex_input: Optional[bool] = False,
+    ) -> torch.Tensor:
+        """
+        Args:
+            operator (function): The operator function that takes in an input
+                tensor x and returns an output tensor y. We will use this to compute
+                the divergence. More specifically, we will perturb the input x by a
+                small amount and compute the divergence between the perturbed output
+                and the reference output.
+
+            x (torch.Tensor): The input tensor of shape (B, C, H, W) to the
+                operator. C=1 or 2: For complex input, the shape is (B, 2, H, W) aka
+                C=2 real. For real input, the shape is (B, 1, H, W) real.
+
+            y_pseudo_gt (torch.Tensor): The pseudo ground truth tensor of shape
+                (B, C, H, W) used to compute the L2 loss. C=1 or 2: For complex
+                input, the shape is (B, 2, H, W) aka C=2 real. For real input, the
+                shape is (B, 1, H, W) real.
+
+            y_ref (torch.Tensor, optional): The reference output tensor of the
+                same shape as y_pseudo_gt. Defaults to None.
+
+            complex_input (bool, optional): Whether the input is complex or not.
+                Defaults to False.
+
+        Returns:
+            sure_loss (torch.Tensor): The SURE loss scalar.
+        """
+
+        # check input shapes
+        if x.dim() != 4:
+            raise ValueError(f"Input tensor x should be 4D, but got {x.dim()}.")
+        if y_pseudo_gt.dim() != 4:
+            raise ValueError(f"Input tensor y_pseudo_gt should be 4D, but got {y_pseudo_gt.dim()}.")
+        if y_ref is not None and y_ref.dim() != 4:
+            raise ValueError(f"Input tensor y_ref should be 4D, but got {y_ref.dim()}.")
+        if x.shape != y_pseudo_gt.shape:
+            raise ValueError(
+                f"Input tensor x and y_pseudo_gt should have the same shape, but got x shape {x.shape}, "
+                f"y_pseudo_gt shape {y_pseudo_gt.shape}."
+            )
+        if y_ref is not None and y_pseudo_gt.shape != y_ref.shape:
+            raise ValueError(
+                f"Input tensor y_pseudo_gt and y_ref should have the same shape, but got y_pseudo_gt shape "
+                f"{y_pseudo_gt.shape}, y_ref shape {y_ref.shape}."
+            )
+
+        # compute loss
+        loss = sure_loss_function(operator, x, y_pseudo_gt, y_ref, self.eps, self.perturb_noise, complex_input)
+
+        return loss
diff --git a/monai/networks/layers/__init__.py b/monai/networks/layers/__init__.py
index d61ed57f7f..3a6e4aa554 100644
--- a/monai/networks/layers/__init__.py
+++ b/monai/networks/layers/__init__.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from .conjugate_gradient import ConjugateGradient
 from .convutils import calculate_out_shape, gaussian_1d, polyval, same_padding, stride_minus_kernel_padding
 from .drop_path import DropPath
 from .factories import Act, Conv, Dropout, LayerFactory, Norm, Pad, Pool, split_args
diff --git a/monai/networks/layers/conjugate_gradient.py b/monai/networks/layers/conjugate_gradient.py
new file mode 100644
index 0000000000..93a45930d7
--- /dev/null
+++ b/monai/networks/layers/conjugate_gradient.py
@@ -0,0 +1,112 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+from torch import nn
+
+
+def _zdot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
+    """
+    Complex dot product between tensors x1 and x2: sum(x1.conj() * x2)
+    """
+    if torch.is_complex(x1):
+        if not torch.is_complex(x2):
+            raise ValueError("x1 and x2 must both be complex")
+        return torch.sum(x1.conj() * x2)
+    else:
+        return torch.sum(x1 * x2)
+
+
+def _zdot_single(x: torch.Tensor) -> torch.Tensor:
+    """
+    Complex dot product between tensor x and itself
+    """
+    res = _zdot(x, x)
+    if torch.is_complex(res):
+        return res.real
+    else:
+        return res
+
+
+class ConjugateGradient(nn.Module):
+    """
+    Conjugate Gradient (CG) solver for linear systems Ax = y.
+
+    For a linear_op that is positive definite and self-adjoint, CG is
+    guaranteed to converge. CG is often used to solve linear systems of the
+    form Ax = y, where A is too large to store explicitly, but its action on a
+    vector can be computed via a linear operator.
+
+    As a result, here we won't set A explicitly as a matrix, but rather as a
+    linear operator. For example, A could be an FFT/IFFT operation.
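+
+    Example:
+        A small sketch: solve Ax = y for a 2x2 symmetric positive definite
+        matrix, which stands in for a large implicit operator::
+
+            import torch
+            from monai.networks.layers import ConjugateGradient
+
+            a_mat = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
+            solver = ConjugateGradient(linear_op=lambda x: a_mat @ x, num_iter=10)
+            y = torch.tensor([1.0, 3.0])
+            x = solver(torch.zeros(2), y)  # close to torch.linalg.solve(a_mat, y)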
+    """
+
+    def __init__(self, linear_op: Callable, num_iter: int):
+        """
+        Args:
+            linear_op: Linear operator
+            num_iter: Number of iterations to run CG
+        """
+        super().__init__()
+
+        self.linear_op = linear_op
+        self.num_iter = num_iter
+
+    def update(
+        self, x: torch.Tensor, p: torch.Tensor, r: torch.Tensor, rsold: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Perform one iteration of the CG method. It takes the current solution x,
+        the current search direction p, the current residual r, and the old
+        residual norm rsold as inputs. Then it computes the new solution, search
+        direction, residual, and residual norm, and returns them.
+        """
+        dy = self.linear_op(p)
+        p_dot_dy = _zdot(p, dy)
+        alpha = rsold / p_dot_dy
+        x = x + alpha * p
+        r = r - alpha * dy
+        rsnew = _zdot_single(r)
+        beta = rsnew / rsold
+        rsold = rsnew
+        p = beta * p + r
+        return x, p, r, rsold
+
+    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+        """
+        Run conjugate gradient for num_iter iterations to solve Ax = y.
+
+        Args:
+            x: tensor (real or complex); initial guess for the linear system Ax = y.
+                The size of x should be applicable to the linear operator. For
+                example, if the linear operator is an FFT, then x is an image
+                (e.g. BCHW); if the linear operator is a matrix multiplication,
+                then x is a vector.
+
+            y: tensor (real or complex); measurement. Same size as x.
+
+        Returns:
+            x: Solution to Ax = y
+        """
+        # compute the initial residual
+        r = y - self.linear_op(x)
+        rsold = _zdot_single(r)
+        p = r
+
+        # iterate, stopping early once the residual norm is negligible
+        for _ in range(self.num_iter):
+            x, p, r, rsold = self.update(x, p, r, rsold)
+            if rsold < 1e-10:
+                break
+        return x
diff --git a/tests/test_conjugate_gradient.py b/tests/test_conjugate_gradient.py
new file mode 100644
index 0000000000..239dbe3ecd
--- /dev/null
+++ b/tests/test_conjugate_gradient.py
@@ -0,0 +1,55 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.networks.layers import ConjugateGradient
+
+
+class TestConjugateGradient(unittest.TestCase):
+    def test_real_valued_inverse(self):
+        """Test ConjugateGradient on a real-valued system: the solver should
+        recover the reference solution given by torch.linalg.solve."""
+        a_dim = 3
+        a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.float)
+
+        def a_op(x):
+            return a_mat @ x
+
+        cg_solver = ConjugateGradient(a_op, num_iter=100)
+        # define the measurement
+        y = torch.tensor([1, 2, 3], dtype=torch.float)
+        # solve for x
+        x = cg_solver(torch.zeros(a_dim), y)
+        x_ref = torch.linalg.solve(a_mat, y)
+        self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
+
+    def test_complex_valued_inverse(self):
+        """Repeat the test with a complex-valued system."""
+        a_dim = 3
+        a_mat = torch.tensor([[1, 2, 3], [2, 1, 2], [3, 2, 1]], dtype=torch.complex64)
+
+        def a_op(x):
+            return a_mat @ x
+
+        cg_solver = ConjugateGradient(a_op, num_iter=100)
+        y = torch.tensor([1, 2, 3], dtype=torch.complex64)
+        x = cg_solver(torch.zeros(a_dim, dtype=torch.complex64), y)
+        x_ref = torch.linalg.solve(a_mat, y)
+        self.assertTrue(torch.allclose(x, x_ref, atol=1e-6))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_sure_loss.py b/tests/test_sure_loss.py
new file mode 100644
index 0000000000..945da657bf
--- /dev/null
+++ b/tests/test_sure_loss.py
@@ -0,0 +1,71 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import torch
+
+from monai.losses import SURELoss
+
+
+class TestSURELoss(unittest.TestCase):
+    def test_real_value(self):
+        """Test SURELoss with real-valued input: with zero perturbation noise and
+        an identity operator, the divergence term vanishes, so the loss should be 0.0."""
+        sure_loss_real = SURELoss(perturb_noise=torch.zeros(2, 1, 128, 128), eps=0.1)
+
+        def operator(x):
+            return x
+
+        y_pseudo_gt = torch.randn(2, 1, 128, 128)
+        x = torch.randn(2, 1, 128, 128)
+        loss = sure_loss_real(operator, x, y_pseudo_gt, complex_input=False)
+        self.assertAlmostEqual(loss.item(), 0.0)
+
+    def test_complex_value(self):
+        """Test SURELoss with complex-valued input: with zero perturbation noise and
+        an identity operator, the loss should likewise be 0.0."""
+
+        def operator(x):
+            return x
+
+        sure_loss_complex = SURELoss(perturb_noise=torch.zeros(2, 2, 128, 128), eps=0.1)
+        y_pseudo_gt = torch.randn(2, 2, 128, 128)
+        x = torch.randn(2, 2, 128, 128)
+        loss = sure_loss_complex(operator, x, y_pseudo_gt, complex_input=True)
+        self.assertAlmostEqual(loss.item(), 0.0)
+
+    def test_complex_general_input(self):
+        """Test SURELoss consistency: with the imaginary channel set to zero, the
+        real- and complex-valued paths should produce matching losses."""
+
+        def operator(x):
+            return x
+
+        perturb_noise_real = torch.randn(2, 1, 128, 128)
+        perturb_noise_complex = torch.zeros(2, 2, 128, 128)
+        perturb_noise_complex[:, 0, :, :] = perturb_noise_real.squeeze()
+        y_pseudo_gt_real = torch.randn(2, 1, 128, 128)
+        y_pseudo_gt_complex = torch.zeros(2, 2, 128, 128)
+        y_pseudo_gt_complex[:, 0, :, :] = y_pseudo_gt_real.squeeze()
+        x_real = torch.randn(2, 1, 128, 128)
+        x_complex = torch.zeros(2, 2, 128, 128)
+        x_complex[:, 0, :, :] = x_real.squeeze()
+
+        sure_loss_real = SURELoss(perturb_noise=perturb_noise_real, eps=0.1)
+        sure_loss_complex = SURELoss(perturb_noise=perturb_noise_complex, eps=0.1)
+
+        loss_real = sure_loss_real(operator, x_real, y_pseudo_gt_real, complex_input=False)
+        loss_complex = sure_loss_complex(operator, x_complex, y_pseudo_gt_complex, complex_input=True)
+        self.assertAlmostEqual(loss_real.item(), loss_complex.abs().item(), places=6)
+
+
+if __name__ == "__main__":
+    unittest.main()