Skip to content

Commit

Permalink
Merge branch 'main' into improvedoc
Browse files Browse the repository at this point in the history
  • Loading branch information
fzimmermann89 committed Jan 15, 2025
2 parents 4e5e62b + aff63fb commit 9d7dc8d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/mrpro/operators/models/WASABI.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,13 @@ def forward(
signal with shape `(offsets *other, coils, z, y, x)`
"""
offsets = unsqueeze_right(self.offsets, b0_shift.ndim - (self.offsets.ndim - 1)) # -1 for offset
delta_b = offsets - b0_shift
delta_b0 = offsets - b0_shift
b1 = self.b1_nominal * relative_b1

signal = (
c
- d
* (torch.pi * b1 * self.gamma * self.rf_duration) ** 2
* torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + delta_b**2)) ** 2
* torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + delta_b0**2)) ** 2
)
return (signal,)
10 changes: 5 additions & 5 deletions src/mrpro/operators/models/WASABITI.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(

if recovery_time.shape != offsets.shape:
raise ValueError(
f'Shape of trec ({recovery_time.shape}) and offsets ({offsets.shape}) needs to be the same.'
f'Shape of recovery_time ({recovery_time.shape}) and offsets ({offsets.shape}) needs to be the same.'
)

# nn.Parameters allow for grad calculation
Expand Down Expand Up @@ -83,16 +83,16 @@ def forward(self, b0_shift: torch.Tensor, relative_b1: torch.Tensor, t1: torch.T
"""
delta_ndim = b0_shift.ndim - (self.offsets.ndim - 1) # -1 for offset
offsets = unsqueeze_right(self.offsets, delta_ndim)
trec = unsqueeze_right(self.recovery_time, delta_ndim)
recovery_time = unsqueeze_right(self.recovery_time, delta_ndim)

b1 = self.b1_nominal * relative_b1
da = offsets - b0_shift
mz_initial = 1.0 - torch.exp(-trec / t1)
delta_b0 = offsets - b0_shift
mz_initial = 1.0 - torch.exp(-recovery_time / t1)

signal = mz_initial * (
1
- 2
* (torch.pi * b1 * self.gamma * self.rf_duration) ** 2
* torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + da**2)) ** 2
* torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + delta_b0**2)) ** 2
)
return (signal,)
64 changes: 32 additions & 32 deletions tests/operators/models/test_wasabiti.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,80 +8,80 @@


def create_data(
offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0
offset_max=500, n_offsets=101, b0_shift=0, relative_b1=1.0, t1=1.0
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
offsets = torch.linspace(-offset_max, offset_max, n_offsets)
return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([t1])
return offsets, torch.Tensor([b0_shift]), torch.Tensor([relative_b1]), torch.Tensor([t1])


def test_WASABITI_symmetry():
"""Test symmetry property of complete WASABITI spectra."""
offsets, b0_shift, rb1, t1 = create_data()
offsets, b0_shift, relative_b1, t1 = create_data()
wasabiti_model = WASABITI(offsets=offsets, recovery_time=torch.ones_like(offsets))
(signal,) = wasabiti_model(b0_shift, rb1, t1)
(signal,) = wasabiti_model(b0_shift, relative_b1, t1)

# check that all values are symmetric around the center
assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center'


def test_WASABITI_symmetry_after_shift():
"""Test symmetry property of shifted WASABITI spectra."""
offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100)
trec = torch.ones_like(offsets_shifted)
wasabiti_model = WASABITI(offsets=offsets_shifted, recovery_time=trec)
(signal_shifted,) = wasabiti_model(b0_shift, rb1, t1)
offsets_shifted, b0_shift, relative_b1, t1 = create_data(b0_shift=100)
recovery_time = torch.ones_like(offsets_shifted)
wasabiti_model = WASABITI(offsets=offsets_shifted, recovery_time=recovery_time)
(signal_shifted,) = wasabiti_model(b0_shift, relative_b1, t1)

lower_index = int((offsets_shifted == -300).nonzero()[0][0])
upper_index = int((offsets_shifted == 500).nonzero()[0][0])

assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift'


def test_WASABITI_asymmetry_for_non_unique_trec():
"""Test symmetry property of WASABITI spectra for non-unique trec values."""
offsets_unshifted, b0_shift, rb1, t1 = create_data(n_offsets=11)
trec = torch.ones_like(offsets_unshifted)
# set first half of trec values to 2.0
trec[: len(offsets_unshifted) // 2] = 2.0
def test_WASABITI_asymmetry_for_non_unique_recovery_time():
"""Test symmetry property of WASABITI spectra for non-unique recovery_time values."""
offsets_unshifted, b0_shift, relative_b1, t1 = create_data(n_offsets=11)
recovery_time = torch.ones_like(offsets_unshifted)
# set first half of recovery_time values to 2.0
recovery_time[: len(offsets_unshifted) // 2] = 2.0

wasabiti_model = WASABITI(offsets=offsets_unshifted, recovery_time=trec)
(signal,) = wasabiti_model(b0_shift, rb1, t1)
wasabiti_model = WASABITI(offsets=offsets_unshifted, recovery_time=recovery_time)
(signal,) = wasabiti_model(b0_shift, relative_b1, t1)

assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center'


@pytest.mark.parametrize('t1', [(1), (2), (3)])
def test_WASABITI_relaxation_term(t1):
"""Test relaxation term (Mzi) of WASABITI model."""
offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1)
trec = torch.ones_like(offset) * t1
wasabiti_model = WASABITI(offsets=offset, recovery_time=trec)
sig = wasabiti_model(b0_shift, rb1, t1)
offset, b0_shift, relative_b1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1)
recovery_time = torch.ones_like(offset) * t1
wasabiti_model = WASABITI(offsets=offset, recovery_time=recovery_time)
sig = wasabiti_model(b0_shift, relative_b1, t1)

assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8)


def test_WASABITI_offsets_trec_mismatch():
def test_WASABITI_offsets_recovery_time_mismatch():
"""Verify error for shape mismatch."""
offsets = torch.ones((1, 2))
trec = torch.ones((1,))
with pytest.raises(ValueError, match='Shape of trec'):
WASABITI(offsets=offsets, recovery_time=trec)
recovery_time = torch.ones((1,))
with pytest.raises(ValueError, match='Shape of recovery_time'):
WASABITI(offsets=offsets, recovery_time=recovery_time)


@SHAPE_VARIATIONS_SIGNAL_MODELS
def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape):
"""Test correct signal shapes."""
ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2)
model_op = WASABITI(ti, trec)
b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3)
(signal,) = model_op(b0_shift, rb1, t1)
ti, recovery_time = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2)
model_op = WASABITI(ti, recovery_time)
b0_shift, relative_b1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3)
(signal,) = model_op(b0_shift, relative_b1, t1)
assert signal.shape == signal_shape


def test_autodiff_WASABITI():
"""Test autodiff works for WASABITI model."""
offset, b0_shift, rb1, t1 = create_data(offset_max=300, n_offsets=2)
trec = torch.ones_like(offset) * t1
wasabiti_model = WASABITI(offsets=offset, recovery_time=trec)
autodiff_test(wasabiti_model, b0_shift, rb1, t1)
offset, b0_shift, relative_b1, t1 = create_data(offset_max=300, n_offsets=2)
recovery_time = torch.ones_like(offset) * t1
wasabiti_model = WASABITI(offsets=offset, recovery_time=recovery_time)
autodiff_test(wasabiti_model, b0_shift, relative_b1, t1)

0 comments on commit 9d7dc8d

Please sign in to comment.