From 760337564188812c5188142ad091b1869d46e053 Mon Sep 17 00:00:00 2001
From: rkcatarina
Date: Wed, 13 Nov 2024 13:41:21 +0100
Subject: [PATCH] add mse

---
 README.md                                           |  4 +-
 examples/qmri_sg_challenge_2024_t1.py               |  6 +--
 examples/qmri_sg_challenge_2024_t2_star.py          |  6 +--
 examples/t1_mapping_with_grad_acq.py                |  8 ++--
 pyproject.toml                                      |  1 +
 src/mrpro/operators/functionals/L1Norm.py           |  2 +
 .../operators/functionals/L1NormViewAsReal.py       |  2 +
 src/mrpro/operators/functionals/MSE.py              | 48 ++++++++++++++++++
 src/mrpro/operators/functionals/__init__.py         |  3 +-
 tests/operators/functionals/__init__.py             |  2 +
 .../operators/functionals/test_functionals.py       | 15 +++++-
 11 files changed, 83 insertions(+), 14 deletions(-)
 create mode 100644 src/mrpro/operators/functionals/MSE.py

diff --git a/README.md b/README.md
index 081865541..4122aa8e8 100644
--- a/README.md
+++ b/README.md
@@ -52,8 +52,8 @@ Quantitative parameter maps can be obtained by creating a functional to be minim
 # Define signal model
 model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti)
 # Define loss function and combine with signal model
-l2norm_squared = L2NormSquared(idata_multi_ti.data.abs(), divide_by_n=True)
-functional = l2norm_squared @ model
+mse = MSE(target=idata_multi_ti.data.abs())
+functional = mse @ model
 [...]
 # Run optimization
 params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)
diff --git a/examples/qmri_sg_challenge_2024_t1.py b/examples/qmri_sg_challenge_2024_t1.py
index eea01c3ff..ee146cb20 100644
--- a/examples/qmri_sg_challenge_2024_t1.py
+++ b/examples/qmri_sg_challenge_2024_t1.py
@@ -16,7 +16,7 @@
 from mrpro.algorithms.optimizers import adam
 from mrpro.data import IData
 from mrpro.operators import MagnitudeOp
-from mrpro.operators.functionals import L2NormSquared
+from mrpro.operators.functionals import MSE
 from mrpro.operators.models import InversionRecovery

 # %% [markdown]
@@ -71,14 +71,14 @@
 # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
 # model $q$.
 # %%
-l2norm_squared = L2NormSquared(idata_multi_ti.data.abs(), divide_by_n=True)
+mse = MSE(target=idata_multi_ti.data.abs())

 # %% [markdown]
 # Now we can simply combine the two into a functional to solve
 #
 # $ \min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$
 # %%
-functional = l2norm_squared @ model
+functional = mse @ model

 # %% [markdown]
 # ### Starting values for the fit
diff --git a/examples/qmri_sg_challenge_2024_t2_star.py b/examples/qmri_sg_challenge_2024_t2_star.py
index 2c8473f0b..e7e28372f 100644
--- a/examples/qmri_sg_challenge_2024_t2_star.py
+++ b/examples/qmri_sg_challenge_2024_t2_star.py
@@ -15,7 +15,7 @@
 from mpl_toolkits.axes_grid1 import make_axes_locatable  # type: ignore [import-untyped]
 from mrpro.algorithms.optimizers import adam
 from mrpro.data import IData
-from mrpro.operators.functionals import L2NormSquared
+from mrpro.operators.functionals import MSE
 from mrpro.operators.models import MonoExponentialDecay

 # %% [markdown]
@@ -78,14 +78,14 @@
 # As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
 # model $q$.
 # %%
-l2norm_squared = L2NormSquared(idata_multi_te.data, divide_by_n=True)
+mse = MSE(target=idata_multi_te.data)

 # %% [markdown]
 # Now we can simply combine the two into a functional which will then solve
 #
 # $ \min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$
 # %%
-functional = l2norm_squared @ model
+functional = mse @ model

 # %% [markdown]
 # ### Carry out fit
diff --git a/examples/t1_mapping_with_grad_acq.py b/examples/t1_mapping_with_grad_acq.py
index 1cdc43ebb..43b798874 100644
--- a/examples/t1_mapping_with_grad_acq.py
+++ b/examples/t1_mapping_with_grad_acq.py
@@ -16,7 +16,7 @@
 from mrpro.data import KData
 from mrpro.data.traj_calculators import KTrajectoryIsmrmrd
 from mrpro.operators import ConstraintsOp, MagnitudeOp
-from mrpro.operators.functionals import L2NormSquared
+from mrpro.operators.functionals import MSE
 from mrpro.operators.models import TransientSteadyStateWithPreparation
 from mrpro.utils import split_idx

@@ -176,14 +176,14 @@
-# As a loss function for the optimizer, we calculate the squared L2 norm between the image data $x$ and our signal
+# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
 # model $q$.
 # %%
-l2norm_squared_loss = L2NormSquared(img_rss_dynamic, divide_by_n=True)
+mse_loss = MSE(target=img_rss_dynamic)

 # %% [markdown]
 # Now we can simply combine the loss function, the signal model and the constraints to solve
 #
 # $$ \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2$$
 # %%
-functional = l2norm_squared_loss @ magnitude_model_op @ constraints_op
+functional = mse_loss @ magnitude_model_op @ constraints_op

 # %% [markdown]
 # ### Carry out fit
diff --git a/pyproject.toml b/pyproject.toml
index 606893d7b..21ec87356 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,6 +34,7 @@ authors = [
     { name = "Johannes Hammacher", email = "johannnes.hammacher@ptb.de" },
     { name = "Stefan Martin", email = "stefan.martin@ptb.de" },
     { name = "Andreas Kofler", email = "andreas.kofler@ptb.de" },
+    { name = "Catarina Redshaw Kranich", email = "catarina.redshaw-kranich@ptb.de" },
 ]
 classifiers = [
     "License :: OSI Approved :: Apache Software License",
diff --git a/src/mrpro/operators/functionals/L1Norm.py b/src/mrpro/operators/functionals/L1Norm.py
index 20380a9eb..29f7b753c 100644
--- a/src/mrpro/operators/functionals/L1Norm.py
+++ b/src/mrpro/operators/functionals/L1Norm.py
@@ -13,6 +13,8 @@ class L1Norm(ElementaryProximableFunctional):
     where W is a either a scalar or tensor that corresponds to a (block-) diagonal operator that is applied to the
     input.

+    In most cases, consider setting `divide_by_n=True` to make the result independent of the input size.
+
     The norm of the vector is computed along the dimensions given at initialization.
     """

diff --git a/src/mrpro/operators/functionals/L1NormViewAsReal.py b/src/mrpro/operators/functionals/L1NormViewAsReal.py
index d8aba9dac..e4227c70b 100644
--- a/src/mrpro/operators/functionals/L1NormViewAsReal.py
+++ b/src/mrpro/operators/functionals/L1NormViewAsReal.py
@@ -15,6 +15,8 @@ class L1NormViewAsReal(ElementaryProximableFunctional):
     If the parameter `weight` is real-valued, :math:`W_r` and :math:`W_i` are both set to `weight`.
     If it is complex-valued, :math:`W_r` and :math:`W_I` are set to the real and imaginary part, respectively.

+    In most cases, consider setting `divide_by_n=True` to make the result independent of the input size.
+
     The norm of the vector is computed along the dimensions set at initialization.
""" diff --git a/src/mrpro/operators/functionals/MSE.py b/src/mrpro/operators/functionals/MSE.py new file mode 100644 index 000000000..08549c4c1 --- /dev/null +++ b/src/mrpro/operators/functionals/MSE.py @@ -0,0 +1,50 @@ +"""MSE-Functional.""" + +from collections.abc import Sequence + +import torch + +from mrpro.operators.functionals.L2NormSquared import L2NormSquared + + +class MSE(L2NormSquared): + r"""Functional class for the mean square error. + + This makes use of the functional L2NormSquared. + """ + + def __init__( + self, + weight: torch.Tensor | complex = 1.0, + target: torch.Tensor | None | complex = None, + dim: int | Sequence[int] | None = None, + divide_by_n: bool = True, + keepdim: bool = False, + ) -> None: + r"""Initialize a Functional. + + We assume that functionals are given in the form + :math:`f(x) = \phi ( weight ( x - target))` + for some functional :math:`\phi`. + + Parameters + ---------- + functional + functional to be employed + weight + weight parameter (see above) + target + target element - often data tensor (see above) + dim + dimension(s) over which functional is reduced. + All other dimensions of `weight ( x - target)` will be treated as batch dimensions. + divide_by_n + if true, the result is scaled by the number of elements of the dimensions index by `dim` in + the tensor `weight ( x - target)`. If true, the functional is thus calculated as the mean, + else the sum. + keepdim + if true, the dimension(s) of the input indexed by dim are maintained and collapsed to singeltons, + else they are removed from the result. + + """ + super().__init__(weight=weight, target=target, dim=dim, divide_by_n=divide_by_n, keepdim=keepdim) diff --git a/src/mrpro/operators/functionals/__init__.py b/src/mrpro/operators/functionals/__init__.py index 48eb6083a..3fe3455d7 100644 --- a/src/mrpro/operators/functionals/__init__.py +++ b/src/mrpro/operators/functionals/__init__.py @@ -1,5 +1,6 @@ from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal from mrpro.operators.functionals.L2NormSquared import L2NormSquared +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional -__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSEDataDiscrepancy", "ZeroFunctional"] +__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSE", "ZeroFunctional"] diff --git a/tests/operators/functionals/__init__.py b/tests/operators/functionals/__init__.py index 19b5d0454..878750b62 100644 --- a/tests/operators/functionals/__init__.py +++ b/tests/operators/functionals/__init__.py @@ -1,3 +1,5 @@ from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal +from mrpro.operators.functionals.L1Norm import L1Norm from mrpro.operators.functionals.L2NormSquared import L2NormSquared +from mrpro.operators.functionals.MSE import MSE from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional diff --git a/tests/operators/functionals/test_functionals.py b/tests/operators/functionals/test_functionals.py index e0b9efd65..150e991b9 100644 --- a/tests/operators/functionals/test_functionals.py +++ b/tests/operators/functionals/test_functionals.py @@ -4,7 +4,7 @@ import pytest import torch from mrpro.operators.Functional import ElementaryFunctional, ElementaryProximableFunctional -from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional +from mrpro.operators.functionals import MSE, L1Norm, L1NormViewAsReal, 
 from tests import RandomGenerator
 from tests.operators.functionals.conftest import (
@@ -296,6 +296,19 @@ class NumericCase(TypedDict):
             [[[-2.983529, -1.943529, -1.049412], [-0.108235, 1.468235, 1.971765]]]
         ),
     },
+    'MSE': {
+        # Generated with ODL
+        'functional': MSE,
+        'x': torch.tensor([[[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]]),
+        'weight': 2.0,
+        'target': torch.tensor([[[0.340, 0.130, 0.230], [0.230, -1.120, -0.190]]]),
+        'sigma': 0.5,
+        'fx_expected': torch.tensor(17.6992),
+        'prox_expected': torch.tensor([[[-1.6640, -1.1480, -0.5080], [0.0920, 0.1520, 1.1240]]]),
+        'prox_convex_conj_expected': torch.tensor(
+            [[[-2.305455, -1.501818, -0.810909], [-0.083636, 1.134545, 1.523636]]]
+        ),
+    },
 }
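
For reviewers who want to try the new functional in isolation, the following is a
minimal, self-contained sketch of how MSE composes with a signal model, mirroring
the README change above. The inversion times, tensor shapes, and optimizer settings
are invented for illustration and are not part of this patch; shapes follow the
(other, coils, z, y, x) convention used for image data.

import torch
from mrpro.algorithms.optimizers import adam
from mrpro.operators import MagnitudeOp
from mrpro.operators.functionals import MSE
from mrpro.operators.models import InversionRecovery

# Hypothetical multi-TI magnitude data: 5 inversion times (in s), one 64x64 slice.
ti = torch.tensor([0.1, 0.5, 1.0, 2.0, 5.0])
data = torch.rand(5, 1, 1, 64, 64)

# Signal model: magnitude of an inversion-recovery curve.
model = MagnitudeOp() @ InversionRecovery(ti=ti)

# MSE data-discrepancy term; divide_by_n defaults to True, so the loss is the
# mean over all elements and therefore independent of the image size.
# target= is passed by keyword because weight is the first positional parameter.
mse = MSE(target=data)

# Composed functional: f(M0, T1) = mean(| |q(M0, T1, TI)| - x |^2)
functional = mse @ model

# Starting values for the fit and a short Adam run (illustrative settings).
m0_start = torch.ones(1, 1, 64, 64)
t1_start = torch.ones(1, 1, 64, 64)
params_result = adam(functional, [m0_start, t1_start], max_iter=100, lr=1e-2)
m0_fit, t1_fit = params_result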