From 3ed0924ad1e7771a2b71a9196e590bcd1b0793f8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 18 Mar 2024 16:17:55 +0100 Subject: [PATCH 1/5] feat: allow reuse of 2D operator in stacked nufft. --- src/mrinufft/operators/stacked.py | 81 ++++++++++++++++++++----------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index a657d6ea..ed9be118 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -6,7 +6,12 @@ import scipy as sp from mrinufft._utils import proper_trajectory, power_method, get_array_module, auto_cast -from mrinufft.operators.base import FourierOperatorBase, check_backend, get_operator +from mrinufft.operators.base import ( + FourierOperatorBase, + check_backend, + get_operator, + with_numpy_cupy, +) from mrinufft.operators.interfaces.utils import ( is_cuda_array, is_host_array, @@ -36,8 +41,13 @@ class MRIStackedNUFFT(FourierOperatorBase): Shape of the image. z_index: array-like Cartesian z index of masked plan. - backend: str + backend: str or FourierOperatorBase Backend to use. + If str, a NUFFT operator is initialized with str being a registered backend. + If FourierOperatorBase, the operator is checked for compatibility and used as is. + notably one should have: + ``n_coils = self.n_coils*len(z_index), squeeze_dims=True, smaps=None`` + smaps: array-like Sensitivity maps. n_coils: int @@ -73,22 +83,48 @@ def __init__( **kwargs, ): super().__init__() - samples2d, z_index_ = self._init_samples(samples, z_index, shape) self.shape = shape - self.samples = samples2d.reshape(-1, 2) - self.z_index = z_index_ self.n_coils = n_coils self.n_batchs = n_batchs self.squeeze_dims = squeeze_dims self.smaps = smaps - self.operator = get_operator(backend)( - self.samples, - shape[:-1], - n_coils=self.n_coils * len(self.z_index), - smaps=None, - squeeze_dims=True, - **kwargs, - ) + if isinstance(backend, str): + samples2d, z_index_ = self._init_samples(samples, z_index, shape) + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.operator = get_operator(backend)( + self.samples, + shape[:-1], + n_coils=self.n_coils * len(self.z_index), + smaps=None, + squeeze_dims=True, + **kwargs, + ) + elif isinstance(backend, FourierOperatorBase): + # get all the interesting values from the operator + if backend.shape != shape[:-1]: + raise ValueError("Backend operator should have compatible shape") + + samples2d, z_index_ = self._init_samples(backend.samples, z_index, shape) + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + + if backend.n_coils != self.n_coils * (len(z_index_)): + raise ValueError( + "The backend operator should have ``n_coils * len(z_index)``" + " specified for its coil dimension." + ) + if backend.uses_sense: + raise ValueError("Backend operator should not uses smaps.") + if not backend.squeeze_dims: + raise ValueError("Backend operator should have ``squeeze_dims=True``") + self.operator = backend + + else: + raise ValueError( + "backend should either be a 2D nufft operator," + " or a str specifying which nufft library to use." + ) @staticmethod def _init_samples(samples, z_index, shape): @@ -145,6 +181,7 @@ def _ifftz(data): sp.fft.ifft(sp.fft.ifftshift(data, axes=-1), axis=-1, norm="ortho"), axes=-1 ) / np.sqrt(2) + @with_numpy_cupy def op(self, data, ksp=None): """Forward operator.""" if self.uses_sense: @@ -189,6 +226,7 @@ def _op_calibless(self, data, ksp=None): ksp = ksp.reshape((B, C, NZ * NS)) return ksp + @with_numpy_cupy def adj_op(self, coeffs, img=None): """Adjoint operator.""" if self.uses_sense: @@ -367,12 +405,10 @@ def _ifftz(data): axes=-1, ) + @with_numpy_cupy def op(self, data, ksp=None): """Forward operator.""" # Dispatch to special case. - xp = get_array_module(data) - if xp.__name__ == "torch" and data.is_cpu: - data = data.numpy() data = auto_cast(data, self.cpx_dtype) if self.uses_sense and is_cuda_array(data): @@ -385,10 +421,6 @@ def op(self, data, ksp=None): op_func = self._op_calibless_host ret = op_func(data, ksp) - if xp.__name__ == "torch" and is_cuda_array(ret): - ret = xp.as_tensor(ret, device=data.device) - elif xp.__name__ == "torch": - ret = xp.from_numpy(ret) return self._safe_squeeze(ret) def _op_sense_host(self, data, ksp=None): @@ -527,12 +559,10 @@ def _op_calibless_device(self, data, ksp=None): ksp = ksp.reshape((B, C, NZ * NS)) return ksp + @with_numpy_cupy def adj_op(self, coeffs, img=None): """Adjoint operator.""" # Dispatch to special case. - xp = get_array_module(coeffs) - if xp.__name__ == "torch" and coeffs.is_cpu: - coeffs = coeffs.numpy() coeffs = auto_cast(coeffs, self.cpx_dtype) if self.uses_sense and is_cuda_array(coeffs): @@ -546,11 +576,6 @@ def adj_op(self, coeffs, img=None): ret = adj_op_func(coeffs, img) - if xp.__name__ == "torch" and is_cuda_array(ret): - ret = xp.as_tensor(ret, device=coeffs.device) - elif xp.__name__ == "torch": - ret = xp.from_numpy(ret) - return self._safe_squeeze(ret) def _adj_op_sense_host(self, coeffs, img_d=None): From b044ff054d8920eeb28d6a5757c8a288a86d06c4 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 19 Mar 2024 15:35:59 +0100 Subject: [PATCH 2/5] use shared operator for cufinufft as well. --- src/mrinufft/operators/stacked.py | 52 ++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index ed9be118..fb71daba 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -323,6 +323,7 @@ def __init__( squeeze_dims=False, smaps_cached=False, density=False, + backend="cufinufft", **kwargs, ): if not (CUPY_AVAILABLE and check_backend("cufinufft")): @@ -331,25 +332,49 @@ def __init__( if (n_batchs * n_coils) % n_trans != 0: raise ValueError("n_batchs * n_coils should be a multiple of n_transf") - samples2d, z_index_ = self._init_samples(samples, z_index, shape) self.shape = shape - self.samples = samples2d.reshape(-1, 2) - self.z_index = z_index_ self.n_coils = n_coils self.n_batchs = n_batchs self.n_trans = n_trans self.squeeze_dims = squeeze_dims + if isinstance(backend, str): + samples2d, z_index_ = self._init_samples(samples, z_index, shape) + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.operator = get_operator(backend)( + self.samples, + shape[:-1], + n_coils=self.n_trans * len(self.z_index), + n_trans=len(self.z_index), + smaps=None, + squeeze_dims=True, + **kwargs, + ) + elif isinstance(backend, FourierOperatorBase): + # get all the interesting values from the operator + if backend.shape != shape[:-1]: + raise ValueError("Backend operator should have compatible shape") + + samples2d, z_index_ = self._init_samples(backend.samples, z_index, shape) + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + + if backend.n_coils != self.n_trans * len(z_index_): + raise ValueError( + "The backend operator should have ``n_coils * len(z_index)``" + " specified for its coil dimension." + ) + if backend.uses_sense: + raise ValueError("Backend operator should not uses smaps.") + if not backend.squeeze_dims: + raise ValueError("Backend operator should have ``squeeze_dims=True``") + self.operator = backend + else: + raise ValueError( + "backend should either be a 2D nufft operator," + " or a str specifying which nufft library to use." + ) - self.operator = get_operator("cufinufft")( - self.samples, - shape[:-1], - n_coils=n_trans * len(self.z_index), - n_trans=len(self.z_index), - smaps=None, - squeeze_dims=True, - density=density, - **kwargs, - ) # Smaps support self.smaps = smaps self.smaps_cached = False @@ -358,7 +383,6 @@ def __init__( raise ValueError( "Smaps should be either a C-ordered ndarray, " "or a GPUArray." ) - self.smaps_cached = False if smaps_cached: warnings.warn( f"{sizeof_fmt(smaps.size * np.dtype(self.cpx_dtype).itemsize)}" From cbcc26697d4b48e0b98ab9372e4c0abe331a33b4 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Thu, 4 Apr 2024 14:04:56 +0200 Subject: [PATCH 3/5] add test for stacked reuse --- tests/test_stacked.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/test_stacked.py b/tests/test_stacked.py index d627f53e..b2199839 100644 --- a/tests/test_stacked.py +++ b/tests/test_stacked.py @@ -134,3 +134,20 @@ def test_stacked2traj3d(): npt.assert_allclose(traj2d, traj2d) npt.assert_allclose(z_index, z_index2) + + +def test_stack_reuse(operator, stacked_op): + """Test the reuse of the stacked operator.""" + + nufft_2d = stacked_op.operator + + reuse_op = MRIStackedNUFFT( + backend=nufft_2d, + shape=stacked_op.shape, + samples=stacked_op.samples, + z_index=stacked_op.z_index, + n_coils=stacked_op.n_coils, + n_batchs=stacked_op.n_batchs, + smaps=stacked_op.smaps, + ) + assert reuse_op.operator is nufft_2d From 630bfbd12055eee624b50f6303f51c84b9627a5c Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 5 Apr 2024 15:15:56 +0200 Subject: [PATCH 4/5] fix: E501 --- src/mrinufft/operators/stacked.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index fb71daba..d3611218 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -44,7 +44,7 @@ class MRIStackedNUFFT(FourierOperatorBase): backend: str or FourierOperatorBase Backend to use. If str, a NUFFT operator is initialized with str being a registered backend. - If FourierOperatorBase, the operator is checked for compatibility and used as is. + If FourierOperatorBase, operator is checked for compatibility and used as is notably one should have: ``n_coils = self.n_coils*len(z_index), squeeze_dims=True, smaps=None`` From 1ac1e26c5d108a7fe131e04d1b904992136f1dee Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 22 Apr 2024 13:02:32 +0200 Subject: [PATCH 5/5] fix: ruff. --- tests/operators/test_stacked.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/operators/test_stacked.py b/tests/operators/test_stacked.py index b2199839..f88dcfa7 100644 --- a/tests/operators/test_stacked.py +++ b/tests/operators/test_stacked.py @@ -138,7 +138,6 @@ def test_stacked2traj3d(): def test_stack_reuse(operator, stacked_op): """Test the reuse of the stacked operator.""" - nufft_2d = stacked_op.operator reuse_op = MRIStackedNUFFT(