From aff63fbf1cd565b7c88f0d8a964492bbe2e58566 Mon Sep 17 00:00:00 2001 From: Felix F Zimmermann Date: Wed, 15 Jan 2025 16:12:41 +0100 Subject: [PATCH] Rename WASABI(TI) parameters (#606) --- src/mrpro/operators/models/WASABI.py | 46 ++++++++--------- src/mrpro/operators/models/WASABITI.py | 64 ++++++++++++------------ tests/operators/models/test_wasabiti.py | 66 ++++++++++++------------- 3 files changed, 89 insertions(+), 87 deletions(-) diff --git a/src/mrpro/operators/models/WASABI.py b/src/mrpro/operators/models/WASABI.py index 207b63c0e..a9c51e723 100644 --- a/src/mrpro/operators/models/WASABI.py +++ b/src/mrpro/operators/models/WASABI.py @@ -13,10 +13,10 @@ class WASABI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor] def __init__( self, offsets: torch.Tensor, - tp: float | torch.Tensor = 0.005, - b1_nom: float | torch.Tensor = 3.70, + rf_duration: float | torch.Tensor = 0.005, + b1_nominal: float | torch.Tensor = 3.70, gamma: float | torch.Tensor = 42.5764, - freq: float | torch.Tensor = 127.7292, + larmor_frequency: float | torch.Tensor = 127.7292, ) -> None: """Initialize WASABI signal model for mapping of B0 and B1 [SCHU2016]_. @@ -24,14 +24,14 @@ def __init__( ---------- offsets frequency offsets [Hz] - with shape (offsets, ...) - tp + with shape `(offsets, ...)` + rf_duration RF pulse duration [s] - b1_nom + b1_nominal nominal B1 amplitude [µT] gamma gyromagnetic ratio [MHz/T] - freq + larmor_frequency larmor frequency [MHz] References @@ -40,18 +40,18 @@ def __init__( field-Inhomogeneity correction of CEST MRI data. MRM 77(2). https://doi.org/10.1002/mrm.26133 """ super().__init__() - # convert all parameters to tensors - tp = torch.as_tensor(tp) - b1_nom = torch.as_tensor(b1_nom) + + rf_duration = torch.as_tensor(rf_duration) + b1_nominal = torch.as_tensor(b1_nominal) gamma = torch.as_tensor(gamma) - freq = torch.as_tensor(freq) + larmor_frequency = torch.as_tensor(larmor_frequency) # nn.Parameters allow for grad calculation self.offsets = nn.Parameter(offsets, requires_grad=offsets.requires_grad) - self.tp = nn.Parameter(tp, requires_grad=tp.requires_grad) - self.b1_nom = nn.Parameter(b1_nom, requires_grad=b1_nom.requires_grad) + self.rf_duration = nn.Parameter(rf_duration, requires_grad=rf_duration.requires_grad) + self.b1_nominal = nn.Parameter(b1_nominal, requires_grad=b1_nominal.requires_grad) self.gamma = nn.Parameter(gamma, requires_grad=gamma.requires_grad) - self.freq = nn.Parameter(freq, requires_grad=freq.requires_grad) + self.larmor_frequency = nn.Parameter(larmor_frequency, requires_grad=larmor_frequency.requires_grad) def forward( self, @@ -66,29 +66,29 @@ def forward( ---------- b0_shift B0 shift [Hz] - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` relative_b1 relative B1 amplitude - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` c additional fit parameter for the signal model - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` d additional fit parameter for the signal model - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` Returns ------- - signal with shape (offsets ... other, coils, z, y, x) + signal with shape `(offsets *other, coils, z, y, x)` """ offsets = unsqueeze_right(self.offsets, b0_shift.ndim - (self.offsets.ndim - 1)) # -1 for offset - delta_x = offsets - b0_shift - b1 = self.b1_nom * relative_b1 + delta_b0 = offsets - b0_shift + b1 = self.b1_nominal * relative_b1 signal = ( c - d - * (torch.pi * b1 * self.gamma * self.tp) ** 2 - * torch.sinc(self.tp * torch.sqrt((b1 * self.gamma) ** 2 + delta_x**2)) ** 2 + * (torch.pi * b1 * self.gamma * self.rf_duration) ** 2 + * torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + delta_b0**2)) ** 2 ) return (signal,) diff --git a/src/mrpro/operators/models/WASABITI.py b/src/mrpro/operators/models/WASABITI.py index ee1e4ae31..5a7a59a04 100644 --- a/src/mrpro/operators/models/WASABITI.py +++ b/src/mrpro/operators/models/WASABITI.py @@ -13,27 +13,27 @@ class WASABITI(SignalModel[torch.Tensor, torch.Tensor, torch.Tensor]): def __init__( self, offsets: torch.Tensor, - trec: torch.Tensor, - tp: float | torch.Tensor = 0.005, - b1_nom: float | torch.Tensor = 3.75, + recovery_time: torch.Tensor, + rf_duration: float | torch.Tensor = 0.005, + b1_nominal: float | torch.Tensor = 3.75, gamma: float | torch.Tensor = 42.5764, - freq: float | torch.Tensor = 127.7292, + larmor_frequency: float | torch.Tensor = 127.7292, ) -> None: """Initialize WASABITI signal model for mapping of B0, B1 and T1 [SCH2023]_. Parameters ---------- offsets - frequency offsets [Hz] with shape (offsets, ...) - trec - recovery time between offsets [s] with shape (offsets, ...) - tp + frequency offsets [Hz] with shape `(offsets, ...)` + recovery_time + recovery time between offsets [s] with shape `(offsets, ...)` + rf_duration RF pulse duration [s] - b1_nom + b1_nominal nominal B1 amplitude [µT] gamma gyromagnetic ratio [MHz/T] - freq + larmor_frequency larmor frequency [MHz] References @@ -44,53 +44,55 @@ def __init__( """ super().__init__() # convert all parameters to tensors - tp = torch.as_tensor(tp) - b1_nom = torch.as_tensor(b1_nom) + rf_duration = torch.as_tensor(rf_duration) + b1_nominal = torch.as_tensor(b1_nominal) gamma = torch.as_tensor(gamma) - freq = torch.as_tensor(freq) + larmor_frequency = torch.as_tensor(larmor_frequency) - if trec.shape != offsets.shape: - raise ValueError(f'Shape of trec ({trec.shape}) and offsets ({offsets.shape}) needs to be the same.') + if recovery_time.shape != offsets.shape: + raise ValueError( + f'Shape of recovery_time ({recovery_time.shape}) and offsets ({offsets.shape}) needs to be the same.' + ) # nn.Parameters allow for grad calculation self.offsets = nn.Parameter(offsets, requires_grad=offsets.requires_grad) - self.trec = nn.Parameter(trec, requires_grad=trec.requires_grad) - self.tp = nn.Parameter(tp, requires_grad=tp.requires_grad) - self.b1_nom = nn.Parameter(b1_nom, requires_grad=b1_nom.requires_grad) + self.recovery_time = nn.Parameter(recovery_time, requires_grad=recovery_time.requires_grad) + self.rf_duration = nn.Parameter(rf_duration, requires_grad=rf_duration.requires_grad) + self.b1_nominal = nn.Parameter(b1_nominal, requires_grad=b1_nominal.requires_grad) self.gamma = nn.Parameter(gamma, requires_grad=gamma.requires_grad) - self.freq = nn.Parameter(freq, requires_grad=freq.requires_grad) + self.larmor_frequency = nn.Parameter(larmor_frequency, requires_grad=larmor_frequency.requires_grad) - def forward(self, b0_shift: torch.Tensor, rb1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: + def forward(self, b0_shift: torch.Tensor, relative_b1: torch.Tensor, t1: torch.Tensor) -> tuple[torch.Tensor,]: """Apply WASABITI signal model. Parameters ---------- b0_shift B0 shift [Hz] - with shape (... other, coils, z, y, x) - rb1 + with shape `(*other, coils, z, y, x)` + relative_b1 relative B1 amplitude - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` t1 longitudinal relaxation time T1 [s] - with shape (... other, coils, z, y, x) + with shape `(*other, coils, z, y, x)` Returns ------- - signal with shape (offsets ... other, coils, z, y, x) + signal with shape `(offsets *other, coils, z, y, x)` """ delta_ndim = b0_shift.ndim - (self.offsets.ndim - 1) # -1 for offset offsets = unsqueeze_right(self.offsets, delta_ndim) - trec = unsqueeze_right(self.trec, delta_ndim) + recovery_time = unsqueeze_right(self.recovery_time, delta_ndim) - b1 = self.b1_nom * rb1 - da = offsets - b0_shift - mz_initial = 1.0 - torch.exp(-trec / t1) + b1 = self.b1_nominal * relative_b1 + delta_b0 = offsets - b0_shift + mz_initial = 1.0 - torch.exp(-recovery_time / t1) signal = mz_initial * ( 1 - 2 - * (torch.pi * b1 * self.gamma * self.tp) ** 2 - * torch.sinc(self.tp * torch.sqrt((b1 * self.gamma) ** 2 + da**2)) ** 2 + * (torch.pi * b1 * self.gamma * self.rf_duration) ** 2 + * torch.sinc(self.rf_duration * torch.sqrt((b1 * self.gamma) ** 2 + delta_b0**2)) ** 2 ) return (signal,) diff --git a/tests/operators/models/test_wasabiti.py b/tests/operators/models/test_wasabiti.py index 637f9ff9e..cde3753fc 100644 --- a/tests/operators/models/test_wasabiti.py +++ b/tests/operators/models/test_wasabiti.py @@ -8,17 +8,17 @@ def create_data( - offset_max=500, n_offsets=101, b0_shift=0, rb1=1.0, t1=1.0 + offset_max=500, n_offsets=101, b0_shift=0, relative_b1=1.0, t1=1.0 ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: offsets = torch.linspace(-offset_max, offset_max, n_offsets) - return offsets, torch.Tensor([b0_shift]), torch.Tensor([rb1]), torch.Tensor([t1]) + return offsets, torch.Tensor([b0_shift]), torch.Tensor([relative_b1]), torch.Tensor([t1]) def test_WASABITI_symmetry(): """Test symmetry property of complete WASABITI spectra.""" - offsets, b0_shift, rb1, t1 = create_data() - wasabiti_model = WASABITI(offsets=offsets, trec=torch.ones_like(offsets)) - (signal,) = wasabiti_model(b0_shift, rb1, t1) + offsets, b0_shift, relative_b1, t1 = create_data() + wasabiti_model = WASABITI(offsets=offsets, recovery_time=torch.ones_like(offsets)) + (signal,) = wasabiti_model(b0_shift, relative_b1, t1) # check that all values are symmetric around the center assert torch.allclose(signal, signal.flipud(), rtol=1e-15), 'Result should be symmetric around center' @@ -26,10 +26,10 @@ def test_WASABITI_symmetry(): def test_WASABITI_symmetry_after_shift(): """Test symmetry property of shifted WASABITI spectra.""" - offsets_shifted, b0_shift, rb1, t1 = create_data(b0_shift=100) - trec = torch.ones_like(offsets_shifted) - wasabiti_model = WASABITI(offsets=offsets_shifted, trec=trec) - (signal_shifted,) = wasabiti_model(b0_shift, rb1, t1) + offsets_shifted, b0_shift, relative_b1, t1 = create_data(b0_shift=100) + recovery_time = torch.ones_like(offsets_shifted) + wasabiti_model = WASABITI(offsets=offsets_shifted, recovery_time=recovery_time) + (signal_shifted,) = wasabiti_model(b0_shift, relative_b1, t1) lower_index = int((offsets_shifted == -300).nonzero()[0][0]) upper_index = int((offsets_shifted == 500).nonzero()[0][0]) @@ -37,15 +37,15 @@ def test_WASABITI_symmetry_after_shift(): assert signal_shifted[lower_index] == signal_shifted[upper_index], 'Result should be symmetric around shift' -def test_WASABITI_asymmetry_for_non_unique_trec(): - """Test symmetry property of WASABITI spectra for non-unique trec values.""" - offsets_unshifted, b0_shift, rb1, t1 = create_data(n_offsets=11) - trec = torch.ones_like(offsets_unshifted) - # set first half of trec values to 2.0 - trec[: len(offsets_unshifted) // 2] = 2.0 +def test_WASABITI_asymmetry_for_non_unique_recovery_time(): + """Test symmetry property of WASABITI spectra for non-unique recovery_time values.""" + offsets_unshifted, b0_shift, relative_b1, t1 = create_data(n_offsets=11) + recovery_time = torch.ones_like(offsets_unshifted) + # set first half of recovery_time values to 2.0 + recovery_time[: len(offsets_unshifted) // 2] = 2.0 - wasabiti_model = WASABITI(offsets=offsets_unshifted, trec=trec) - (signal,) = wasabiti_model(b0_shift, rb1, t1) + wasabiti_model = WASABITI(offsets=offsets_unshifted, recovery_time=recovery_time) + (signal,) = wasabiti_model(b0_shift, relative_b1, t1) assert not torch.allclose(signal, signal.flipud(), rtol=1e-8), 'Result should not be symmetric around center' @@ -53,35 +53,35 @@ def test_WASABITI_asymmetry_for_non_unique_trec(): @pytest.mark.parametrize('t1', [(1), (2), (3)]) def test_WASABITI_relaxation_term(t1): """Test relaxation term (Mzi) of WASABITI model.""" - offset, b0_shift, rb1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) - trec = torch.ones_like(offset) * t1 - wasabiti_model = WASABITI(offsets=offset, trec=trec) - sig = wasabiti_model(b0_shift, rb1, t1) + offset, b0_shift, relative_b1, t1 = create_data(offset_max=50000, n_offsets=1, t1=t1) + recovery_time = torch.ones_like(offset) * t1 + wasabiti_model = WASABITI(offsets=offset, recovery_time=recovery_time) + sig = wasabiti_model(b0_shift, relative_b1, t1) assert torch.isclose(sig[0], torch.FloatTensor([1 - torch.exp(torch.FloatTensor([-1]))]), rtol=1e-8) -def test_WASABITI_offsets_trec_mismatch(): +def test_WASABITI_offsets_recovery_time_mismatch(): """Verify error for shape mismatch.""" offsets = torch.ones((1, 2)) - trec = torch.ones((1,)) - with pytest.raises(ValueError, match='Shape of trec'): - WASABITI(offsets=offsets, trec=trec) + recovery_time = torch.ones((1,)) + with pytest.raises(ValueError, match='Shape of recovery_time'): + WASABITI(offsets=offsets, recovery_time=recovery_time) @SHAPE_VARIATIONS_SIGNAL_MODELS def test_WASABITI_shape(parameter_shape, contrast_dim_shape, signal_shape): """Test correct signal shapes.""" - ti, trec = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) - model_op = WASABITI(ti, trec) - b0_shift, rb1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) - (signal,) = model_op(b0_shift, rb1, t1) + ti, recovery_time = create_parameter_tensor_tuples(contrast_dim_shape, number_of_tensors=2) + model_op = WASABITI(ti, recovery_time) + b0_shift, relative_b1, t1 = create_parameter_tensor_tuples(parameter_shape, number_of_tensors=3) + (signal,) = model_op(b0_shift, relative_b1, t1) assert signal.shape == signal_shape def test_autodiff_WASABITI(): """Test autodiff works for WASABITI model.""" - offset, b0_shift, rb1, t1 = create_data(offset_max=300, n_offsets=2) - trec = torch.ones_like(offset) * t1 - wasabiti_model = WASABITI(offsets=offset, trec=trec) - autodiff_test(wasabiti_model, b0_shift, rb1, t1) + offset, b0_shift, relative_b1, t1 = create_data(offset_max=300, n_offsets=2) + recovery_time = torch.ones_like(offset) * t1 + wasabiti_model = WASABITI(offsets=offset, recovery_time=recovery_time) + autodiff_test(wasabiti_model, b0_shift, relative_b1, t1)