From d7521bd0f67856e4462737066a534ee8991dcddc Mon Sep 17 00:00:00 2001 From: Pierre-Antoine Comby Date: Wed, 21 Feb 2024 11:53:05 +0100 Subject: [PATCH] feat: add batch support for gpunufft. (#81) * feat: add batch support for gpunufft. * Make test density stronger * Fix testing * Bump gpuNUFFT supported version * Make sure test does not run for other than gpuNUFFT * Fix voronoi * ruff * Fixes --------- Co-authored-by: GILIYAR RADHAKRISHNA Chaithya --- pyproject.toml | 2 +- src/mrinufft/operators/interfaces/gpunufft.py | 83 ++++++++++++++----- tests/case_trajectories.py | 13 ++- tests/helpers/asserts.py | 8 +- tests/test_batch.py | 5 +- tests/test_density.py | 43 ++++++---- 6 files changed, 117 insertions(+), 37 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bb12bafb3..759b54038 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dynamic = ["version"] [project.optional-dependencies] -gpunufft = ["gpuNUFFT>=0.7.1", "cupy-cuda11x"] +gpunufft = ["gpuNUFFT>=0.7.4", "cupy-cuda11x"] cufinufft = ["cufinufft", "cupy-cuda11x"] finufft = ["finufft"] pynfft = ["pynfft2", "cython<3.0.0"] diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 0ae24ca9c..7cc5d9261 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -170,7 +170,7 @@ def __init__( ) self.pinned_image = pinned_image self.pinned_kspace = pinned_kspace - + self.osf = osf self.pinned_smaps = pinned_smaps self.operator = NUFFTOp( np.reshape(samples, samples.shape[::-1], order="F"), @@ -193,8 +193,9 @@ def _reshape_image(self, image, direction="op"): return xp.asarray([c.ravel(order="F") for c in image], dtype=xp.complex64).T else: if self.uses_sense or self.n_coils == 1: - return image.squeeze().astype(xp.complex64).T - return xp.asarray([c.T for c in image], dtype=xp.complex64) + # Support for one additional dimension + return image.squeeze().astype(xp.complex64).T[None] + return xp.asarray([c.T for c in image], dtype=xp.complex64).squeeze() def op_direct(self, image, kspace=None, interpolate_data=False): """Compute the masked non-Cartesian Fourier transform. @@ -304,15 +305,16 @@ def adj_op_direct(self, coeffs, image=None, grid_data=False): adjoint operator of Non Uniform Fourier transform of the input coefficients. """ + C = 1 if self.uses_sense else self.n_coils coeffs = coeffs.astype(cp.complex64) if image is None: image = cp.empty( - (np.prod(self.shape), (1 if self.uses_sense else self.n_coils)), + (np.prod(self.shape), C), dtype=cp.complex64, order="F", ) self.operator.adj_op_direct(coeffs.data.ptr, image.data.ptr, grid_data) - image = image.reshape(self.n_coils, *self.shape[::-1]) + image = image.reshape(C, *self.shape[::-1]) return self._reshape_image(image, "adjoint") @@ -334,10 +336,12 @@ class MRIGpuNUFFT(FourierOperatorBase): if True, the density compensation is estimated from the samples locations. If an array is passed, it is used as the density compensation. - squeeze_dims: bool default True - This has no effect, gpuNUFFT always squeeze the data. + squeeze_dims: bool, default True + If True, will try to remove the singleton dimension for batch and coils. smaps: np.ndarray default None Holds the sensitivity maps for SENSE reconstruction. + n_trans: int, default =1 + This has no effect for now. kwargs: extra keyword args these arguments are passed to gpuNUFFT operator. This is used only in gpuNUFFT @@ -351,9 +355,11 @@ def __init__( samples, shape, n_coils=1, + n_batchs=1, + n_trans=1, density=None, smaps=None, - squeeze_dims=False, + squeeze_dims=True, eps=1e-3, **kwargs, ): @@ -367,8 +373,9 @@ def __init__( self.samples = proper_trajectory(samples, normalize="unit") self.dtype = self.samples.dtype self.n_coils = n_coils + self.n_batchs = n_batchs self.smaps = smaps - + self.squeeze_dims = squeeze_dims self.compute_density(density) self.impl = RawGpuNUFFT( samples=self.samples, @@ -396,21 +403,32 @@ def op(self, data, coeffs=None): np.ndarray Masked Fourier transform of the input image. """ + B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples + + op_func = self.impl.op if is_cuda_array(data): + op_func = self.impl.op_direct if not self.impl.use_gpu_direct: warnings.warn( "Using direct GPU array without passing " "`use_gpu_direct=True`, this is memory inefficient." ) - return self.impl.op_direct(data, coeffs) - return self.impl.op( - data, - coeffs, - ) + data_ = data.reshape((B, 1 if self.uses_sense else C, *XYZ)) + if coeffs is not None: + coeffs.reshape((B, C, K)) + result = [] + for i in range(B): + if coeffs is None: + result.append(op_func(data_[i], None)) + else: + op_func(data_[i], coeffs[i]) + if coeffs is None: + coeffs = get_array_module(data).stack(result) + return self._safe_squeeze(coeffs) @with_numpy_cupy def adj_op(self, coeffs, data=None): - """Compute adjoint Non Unform Fourier Transform. + """Compute adjoint Non Uniform Fourier Transform. Parameters ---------- @@ -424,14 +442,28 @@ def adj_op(self, coeffs, data=None): np.ndarray Inverse discrete Fourier transform of the input coefficients. """ + B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples + + adj_op_func = self.impl.adj_op if is_cuda_array(coeffs): + adj_op_func = self.impl.adj_op_direct if not self.impl.use_gpu_direct: warnings.warn( "Using direct GPU array without passing " "`use_gpu_direct=True`, this is memory inefficient." ) - return self.impl.adj_op_direct(coeffs, data) - return self.impl.adj_op(coeffs, data) + coeffs_ = coeffs.reshape(B, C, K) + if data is not None: + data.reshape((B, 1 if self.uses_sense else C, *XYZ)) + result = [] + for i in range(B): + if data is None: + result.append(adj_op_func(coeffs_[i], None)) + else: + adj_op_func(coeffs_[i], data[i]) + if data is None: + data = get_array_module(coeffs).stack(result) + return self._safe_squeeze(data) @property def uses_sense(self): @@ -457,10 +489,11 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): raise ValueError( "gpuNUFFT is not available, cannot " "estimate the density compensation" ) + volume_shape = (np.array(volume_shape) * osf).astype(int) grid_op = MRIGpuNUFFT( samples=kspace_loc, shape=volume_shape, - osf=osf, + osf=1, **kwargs, ) density_comp = grid_op.impl.operator.estimate_density_comp( @@ -471,7 +504,6 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs): def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs): """Return the Lipschitz constant of the operator. - Parameters ---------- max_iter: int Number of iteration to perform to estimate the Lipschitz constant. @@ -497,3 +529,16 @@ def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs): return tmp_op.impl.operator.get_spectral_radius( max_iter=max_iter, tolerance=tolerance ) + + def _safe_squeeze(self, arr): + """Squeeze the first two dimensions of shape of the operator.""" + if self.squeeze_dims: + try: + arr = arr.squeeze(axis=1) + except ValueError: + pass + try: + arr = arr.squeeze(axis=0) + except ValueError: + pass + return arr diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index 14dd66582..e44265797 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -4,7 +4,7 @@ import scipy as sp from mrinufft.trajectories import initialize_2D_radial -from mrinufft.trajectories.tools import stack +from mrinufft.trajectories.tools import stack, rotate class CasesTrajectories: @@ -34,12 +34,23 @@ def case_radial2D(self, Nc=10, Ns=500, N=64): trajectory = initialize_2D_radial(Nc, Ns) return trajectory, (N, N) + def case_nyquist_radial2D(self, Nc=32 * 4, Ns=16, N=32): + """Create a 2D radial trajectory.""" + trajectory = initialize_2D_radial(Nc, Ns) + return trajectory, (N, N) + def case_radial3D(self, Nc=20, Ns=1000, Nr=20, N=64, expansion="rotations"): """Create a 3D radial trajectory.""" trajectory = initialize_2D_radial(Nc, Ns) trajectory = stack(trajectory, nb_stacks=Nr) return trajectory, (N, N, N) + def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32): + """Create a 3D radial trajectory.""" + trajectory = initialize_2D_radial(Nc, Ns) + trajectory = rotate(trajectory, nb_rotations=Nr) + return trajectory, (N, N, N) + def case_grid2D(self, N=16): """Create a 2D cartesian grid of frequencies locations.""" freq_1d = sp.fft.fftfreq(N) diff --git a/tests/helpers/asserts.py b/tests/helpers/asserts.py index 9198c8063..3d7d3113c 100644 --- a/tests/helpers/asserts.py +++ b/tests/helpers/asserts.py @@ -51,7 +51,13 @@ def assert_correlate(a, b, slope=1.0, slope_err=1e-3, r_value_err=1e-3): a.flatten(), b.flatten() ) abs_slope_reg = abs(slope_reg) - if abs(abs_slope_reg - slope) > slope_err: + if r_value_err is not None and abs(rvalue - 1) > r_value_err: + raise AssertionError( + f"RValue {rvalue} != 1 +- {r_value_err}\n " + f"intercept={intercept}, stderr={stderr}, " + f"intercept_stderr={intercept_stderr}" + ) + if slope_err is not None and abs(abs_slope_reg - slope) > slope_err: raise AssertionError( f"Slope {abs_slope_reg} != {slope} +- {slope_err}\n r={rvalue}," f"intercept={intercept}, stderr={stderr}, " diff --git a/tests/test_batch.py b/tests/test_batch.py index d1698c8d4..ba26d884a 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -5,6 +5,7 @@ import numpy as np import numpy.testing as npt +import pytest from pytest_cases import parametrize_with_cases, parametrize, fixture from helpers import ( assert_correlate, @@ -37,7 +38,7 @@ cases=CasesTrajectories, glob="*random*", ) -@parametrize(backend=["finufft", "cufinufft"]) +@parametrize(backend=["gpunufft", "finufft", "cufinufft"]) def operator( request, kspace_locs, @@ -49,6 +50,8 @@ def operator( backend="finufft", ): """Generate a batch operator.""" + if n_trans != 1 and backend == "gpunufft": + pytest.skip("Duplicate case.") if sense: smaps = 1j * np.random.rand(n_coils, *shape) smaps += np.random.rand(n_coils, *shape) diff --git a/tests/test_density.py b/tests/test_density.py index 771557394..2cba4c9b3 100644 --- a/tests/test_density.py +++ b/tests/test_density.py @@ -2,11 +2,11 @@ import numpy as np import numpy.testing as npt -from pytest_cases import fixture, parametrize, parametrize_with_cases +from pytest_cases import parametrize, parametrize_with_cases from case_trajectories import CasesTrajectories from helpers import assert_correlate -from mrinufft.density import cell_count, voronoi +from mrinufft.density import cell_count, voronoi, pipe from mrinufft.density.utils import normalize_weights from mrinufft._utils import proper_trajectory @@ -37,15 +37,11 @@ def slow_cell_count2D(traj, shape, osf): return normalize_weights(weights) -@fixture(scope="module") -def radial_distance(): +def radial_distance(traj, shape): """Compute the radial distance of a trajectory.""" - traj, shape = CasesTrajectories().case_radial2D() - proper_traj = proper_trajectory(traj, normalize="unit") - weights = 2 * np.pi * np.sqrt(proper_traj[:, 0] ** 2 + proper_traj[:, 1] ** 2) - - return normalize_weights(weights) + weights = np.linalg.norm(proper_traj, axis=-1) + return weights @parametrize("osf", [1, 1.25, 2]) @@ -58,13 +54,32 @@ def test_cell_count2D(traj, shape, osf): @parametrize_with_cases("traj, shape", cases=[CasesTrajectories.case_radial2D]) -def test_voronoi(traj, shape, radial_distance): +def test_voronoi(traj, shape): """Test the voronoi method.""" result = voronoi(traj) + distance = radial_distance(traj, shape) + result = result / np.mean(result) + distance = distance / np.mean(distance) + assert_correlate(result, distance, slope=1) - assert_correlate(result, radial_distance, slope=2 * np.pi) - -def test_pipe(): +@parametrize("osf", [1, 1.25, 2]) +@parametrize_with_cases( + "traj, shape", + cases=[ + CasesTrajectories.case_nyquist_radial2D, + CasesTrajectories.case_nyquist_radial3D, + ], +) +@parametrize(backend=["gpunufft"]) +def test_pipe(backend, traj, shape, osf): """Test the pipe method.""" - pass + distance = radial_distance(traj, shape) + result = pipe(traj, shape, osf=osf, num_iterations=10) + result = result / np.mean(result) + distance = distance / np.mean(distance) + if osf != 2: + # If OSF < 2, we dont perfectly estimate + assert_correlate(result, distance, slope=1, slope_err=None, r_value_err=0.2) + else: + assert_correlate(result, distance, slope=1, slope_err=0.1, r_value_err=0.1)