Skip to content

Commit

Permalink
add mse
Browse files Browse the repository at this point in the history
  • Loading branch information
rkcatarina committed Nov 13, 2024
1 parent 475a7cc commit 7603375
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ Quantitative parameter maps can be obtained by creating a functional to be minim
# Define signal model
model = MagnitudeOp() @ InversionRecovery(ti=idata_multi_ti.header.ti)
# Define loss function and combine with signal model
l2norm_squared = L2NormSquared(idata_multi_ti.data.abs(), divide_by_n=True)
functional = l2norm_squared @ model
mse = MSE(idata_multi_ti.data.abs())
functional = mse @ model
[...]
# Run optimization
params_result = adam(functional, [m0_start, t1_start], max_iter=max_iter, lr=lr)
Expand Down
6 changes: 3 additions & 3 deletions examples/qmri_sg_challenge_2024_t1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mrpro.algorithms.optimizers import adam
from mrpro.data import IData
from mrpro.operators import MagnitudeOp
from mrpro.operators.functionals import L2NormSquared
from mrpro.operators.functionals import MSE
from mrpro.operators.models import InversionRecovery

# %% [markdown]
Expand Down Expand Up @@ -71,14 +71,14 @@
# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
# model $q$.
# %%
l2norm_squared = L2NormSquared(idata_multi_ti.data.abs(), divide_by_n=True)
mse = MSE(idata_multi_ti.data.abs())

# %% [markdown]
# Now we can simply combine the two into a functional to solve
#
# $ \min_{M_0, T1} || |q(M_0, T1, TI)| - x||_2^2$
# %%
functional = l2norm_squared @ model
functional = mse @ model

# %% [markdown]
# ### Starting values for the fit
Expand Down
6 changes: 3 additions & 3 deletions examples/qmri_sg_challenge_2024_t2_star.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mpl_toolkits.axes_grid1 import make_axes_locatable # type: ignore [import-untyped]
from mrpro.algorithms.optimizers import adam
from mrpro.data import IData
from mrpro.operators.functionals import L2NormSquared
from mrpro.operators.functionals import MSE
from mrpro.operators.models import MonoExponentialDecay

# %% [markdown]
Expand Down Expand Up @@ -78,14 +78,14 @@
# As a loss function for the optimizer, we calculate the mean-squared error between the image data $x$ and our signal
# model $q$.
# %%
l2norm_squared = L2NormSquared(idata_multi_te.data, divide_by_n=True)
mse = MSE(idata_multi_te.data)

# %% [markdown]
# Now we can simply combine the two into a functional which will then solve
#
# $ \min_{M_0, T2^*} ||q(M_0, T2^*, TE) - x||_2^2$
# %%
functional = l2norm_squared @ model
functional = mse @ model

# %% [markdown]
# ### Carry out fit
Expand Down
6 changes: 3 additions & 3 deletions examples/t1_mapping_with_grad_acq.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from mrpro.data import KData
from mrpro.data.traj_calculators import KTrajectoryIsmrmrd
from mrpro.operators import ConstraintsOp, MagnitudeOp
from mrpro.operators.functionals import L2NormSquared
from mrpro.operators.functionals import MSE
from mrpro.operators.models import TransientSteadyStateWithPreparation
from mrpro.utils import split_idx

Expand Down Expand Up @@ -176,14 +176,14 @@
# As a loss function for the optimizer, we calculate the squared L2 norm between the image data $x$ and our signal
# model $q$.
# %%
l2norm_squared_loss = L2NormSquared(img_rss_dynamic, divide_by_n=True)
mse_loss = MSE(img_rss_dynamic)

# %% [markdown]
# Now we can simply combine the loss function, the signal model and the constraints to solve
#
# $$ \min_{M_0, T_1, \alpha} || |q(M_0, T_1, \alpha)| - x||_2^2$$
# %%
functional = l2norm_squared_loss @ magnitude_model_op @ constraints_op
functional = mse_loss @ magnitude_model_op @ constraints_op

# %% [markdown]
# ### Carry out fit
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ authors = [
{ name = "Johannes Hammacher", email = "johannnes.hammacher@ptb.de" },
{ name = "Stefan Martin", email = "stefan.martin@ptb.de" },
{ name = "Andreas Kofler", email = "andreas.kofler@ptb.de" },
{ name = "Catarina Redshaw Kranich", email = "catarina.redshaw-kranich@ptb.de" },
]
classifiers = [
"License :: OSI Approved :: Apache Software License",
Expand Down
2 changes: 2 additions & 0 deletions src/mrpro/operators/functionals/L1Norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ class L1Norm(ElementaryProximableFunctional):
where W is a either a scalar or tensor that corresponds to a (block-) diagonal operator
that is applied to the input.
In most cases, consider setting divide_by_n to true to be independent of input size.
The norm of the vector is computed along the dimensions given at initialization.
"""

Expand Down
2 changes: 2 additions & 0 deletions src/mrpro/operators/functionals/L1NormViewAsReal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ class L1NormViewAsReal(ElementaryProximableFunctional):
If the parameter `weight` is real-valued, :math:`W_r` and :math:`W_i` are both set to `weight`.
If it is complex-valued, :math:`W_r` and :math:`W_I` are set to the real and imaginary part, respectively.
In most cases, consider setting divide_by_n to true to be independent of input size.
The norm of the vector is computed along the dimensions set at initialization.
"""

Expand Down
50 changes: 50 additions & 0 deletions src/mrpro/operators/functionals/MSE.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""MSE-Functional."""

from collections.abc import Sequence

import torch

from mrpro.operators.functionals.L2NormSquared import L2NormSquared


class MSE(L2NormSquared):
r"""Functional class for the mean square error.
This makes use of the functional L2NormSquared.
"""

def __init__(
self,
weight: torch.Tensor | complex = 1.0,
target: torch.Tensor | None | complex = None,
dim: int | Sequence[int] | None = None,
divide_by_n: bool = True,
keepdim: bool = False,
) -> None:
r"""Initialize a Functional.
We assume that functionals are given in the form
:math:`f(x) = \phi ( weight ( x - target))`
for some functional :math:`\phi`.
Parameters
----------
functional
functional to be employed
weight
weight parameter (see above)
target
target element - often data tensor (see above)
dim
dimension(s) over which functional is reduced.
All other dimensions of `weight ( x - target)` will be treated as batch dimensions.
divide_by_n
if true, the result is scaled by the number of elements of the dimensions index by `dim` in
the tensor `weight ( x - target)`. If true, the functional is thus calculated as the mean,
else the sum.
keepdim
if true, the dimension(s) of the input indexed by dim are maintained and collapsed to singeltons,
else they are removed from the result.
"""
super().__init__(weight=weight, target=target, dim=dim, divide_by_n=divide_by_n, keepdim=keepdim)
3 changes: 2 additions & 1 deletion src/mrpro/operators/functionals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from mrpro.operators.functionals.L1Norm import L1Norm
from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal
from mrpro.operators.functionals.L2NormSquared import L2NormSquared
from mrpro.operators.functionals.MSE import MSE
from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional
__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSEDataDiscrepancy", "ZeroFunctional"]
__all__ = ["L1Norm", "L1NormViewAsReal", "L2NormSquared", "MSE", "ZeroFunctional"]
2 changes: 2 additions & 0 deletions tests/operators/functionals/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from mrpro.operators.functionals.L1NormViewAsReal import L1NormViewAsReal
from mrpro.operators.functionals.L1Norm import L1Norm
from mrpro.operators.functionals.L2NormSquared import L2NormSquared
from mrpro.operators.functionals.MSE import MSE
from mrpro.operators.functionals.ZeroFunctional import ZeroFunctional
15 changes: 14 additions & 1 deletion tests/operators/functionals/test_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
import torch
from mrpro.operators.Functional import ElementaryFunctional, ElementaryProximableFunctional
from mrpro.operators.functionals import L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional
from mrpro.operators.functionals import MSE, L1Norm, L1NormViewAsReal, L2NormSquared, ZeroFunctional

from tests import RandomGenerator
from tests.operators.functionals.conftest import (
Expand Down Expand Up @@ -296,6 +296,19 @@ class NumericCase(TypedDict):
[[[-2.983529, -1.943529, -1.049412], [-0.108235, 1.468235, 1.971765]]]
),
},
'MSE': {
# Generated with ODL
'functional': MSE,
'x': torch.tensor([[[-3.0, -2.0, -1.0], [0.0, 1.0, 2.0]]]),
'weight': 2.0,
'target': torch.tensor([[[0.340, 0.130, 0.230], [0.230, -1.120, -0.190]]]),
'sigma': 0.5,
'fx_expected': torch.tensor(17.6992),
'prox_expected': torch.tensor([[[-1.6640, -1.1480, -0.5080], [0.0920, 0.1520, 1.1240]]]),
'prox_convex_conj_expected': torch.tensor(
[[[-2.305455, -1.501818, -0.810909], [-0.083636, 1.134545, 1.523636]]]
),
},
}


Expand Down

0 comments on commit 7603375

Please sign in to comment.