Skip to content

Commit

Permalink
Change MOLLI from (a,b,T1) to (a,c=b/a,T1) (#406)
Browse files Browse the repository at this point in the history
  • Loading branch information
ckolbPTB authored Sep 19, 2024
1 parent 567089a commit f86717d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
21 changes: 15 additions & 6 deletions src/mrpro/operators/models/MOLLI.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,18 @@


class MOLLI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor]):
"""Signal model for Modified Look-Locker inversion recovery (MOLLI)."""
"""Signal model for Modified Look-Locker inversion recovery (MOLLI).
This model describes
:math:`M_z(t) = a(1 - c)e^{(-t / T1^*)}` with :math:`T1^* = T1 / (c - 1)`.
This is a small modification from the original MOLLI signal model [MESS2004]_:
:math:`M_z(t) = a - be^{(-t / T1^*)}` with :math:`T1^* = T1 / (b/a - 1)`.
.. [MESS2004] Messroghli DR, Radjenovic A, Kozerke S, Higgins DM, Sivananthan MU, Ridgway JP (2004) Modified
look-locker inversion recovery (MOLLI) for high-resolution T 1 mapping of the heart. MRM, 52(1).
https://doi.org/10.1002/mrm.20110
"""

def __init__(self, ti: float | torch.Tensor):
"""Initialize MOLLI signal model for T1 mapping.
Expand All @@ -21,16 +32,16 @@ def __init__(self, ti: float | torch.Tensor):
ti = torch.as_tensor(ti)
self.ti = torch.nn.Parameter(ti, requires_grad=ti.requires_grad)

def forward(self, a: torch.Tensor, b: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]:
def forward(self, a: torch.Tensor, c: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]:
"""Apply MOLLI signal model.
Parameters
----------
a
parameter a in MOLLI signal model
with shape (... other, coils, z, y, x)
b
parameter b in MOLLI signal model
c
parameter c = b/a in MOLLI signal model
with shape (... other, coils, z, y, x)
t1
longitudinal relaxation time T1
Expand All @@ -41,7 +52,5 @@ def forward(self, a: torch.Tensor, b: torch.Tensor, t1: torch.Tensor) -> tuple[t
signal with shape (time ... other, coils, z, y, x)
"""
ti = self.expand_tensor_dim(self.ti, a.ndim - (self.ti.ndim - 1)) # -1 for time
c = b / torch.where(a == 0, 1e-10, a)
t1 = torch.where(t1 == 0, t1 + 1e-10, t1)
signal = a * (1 - c * torch.exp(ti / t1 * (1 - c)))
return (signal,)
26 changes: 13 additions & 13 deletions tests/operators/models/test_molli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,35 @@
import pytest
import torch
from mrpro.operators.models import MOLLI
from tests import RandomGenerator
from tests.operators.models.conftest import SHAPE_VARIATIONS_SIGNAL_MODELS, create_parameter_tensor_tuples


@pytest.mark.parametrize(
('ti', 'result'),
[
(0, 'a-b'), # short ti
(0, 'a(1-c)'), # short ti
(20, 'a'), # long ti
],
)
def test_molli(ti, result):
"""Test for MOLLI.
Checking that idata output tensor at ti=0 is close to a. Checking
that idata output tensor at large ti is close to a-b.
that idata output tensor at large ti is close to a(1-c).
"""
# Generate qdata tensor, not random as a<b is necessary for t1_star to be >= 0
other, coils, z, y, x = 10, 5, 100, 100, 100
a = torch.ones((other, coils, z, y, x)) * 2
b = torch.ones((other, coils, z, y, x)) * 5
t1 = torch.ones((other, coils, z, y, x)) * 2
a, t1 = create_parameter_tensor_tuples()
# c>2 is necessary for t1_star to be >= 0 and to ensure t1_star < t1
random_generator = RandomGenerator(seed=0)
c = random_generator.float32_tensor(size=a.shape, low=2.0, high=4.0)

# Generate signal model and torch tensor for comparison
model = MOLLI(ti)
(image,) = model.forward(a, b, t1)
(image,) = model.forward(a, c, t1)

# Assert closeness to a-b for large ti
if result == 'a-b':
torch.testing.assert_close(image[0, ...], a - b)
# Assert closeness to a(1-c) for large ti
if result == 'a(1-c)':
torch.testing.assert_close(image[0, ...], a * (1 - c))
# Assert closeness to a for ti=0
elif result == 'a':
torch.testing.assert_close(image[0, ...], a)
Expand All @@ -42,6 +42,6 @@ def test_molli_shape(parameter_shape, contrast_dim_shape, signal_shape):
"""Test correct signal shapes."""
(ti,) = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=1)
model_op = MOLLI(ti)
a, b, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3)
(signal,) = model_op.forward(a, b, t1)
a, c, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3)
(signal,) = model_op.forward(a, c, t1)
assert signal.shape == signal_shape

0 comments on commit f86717d

Please sign in to comment.