diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 96f2b4f4..147e0001 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -119,6 +119,7 @@ jobs: cmake -DFINUFFT_USE_CUDA=1 ../ && cmake --build . && cp libcufinufft.so ../python/cufinufft/. # enter venv source $RUNNER_WORKSPACE/venv/bin/activate + pip install cupy-cuda11x cd $RUNNER_WORKSPACE/finufft/python/cufinufft python setup.py develop # FIXME: This is hardcoded diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index eac477d1..214d57c7 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -368,7 +368,8 @@ def _op_sense_host(self, data, ksp=None): coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) dataf = data.reshape((B, *XYZ)) data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - ksp = ksp or np.empty((B, C, K), dtype=self.cpx_dtype) + if ksp is None: + ksp = np.empty((B, C, K), dtype=self.cpx_dtype) ksp = ksp.reshape((B * C, K)) ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) @@ -408,8 +409,9 @@ def _op_calibless_host(self, data, ksp=None): coil_img_d = cp.empty(np.prod(XYZ) * T, dtype=self.cpx_dtype) ksp_d = cp.empty((T, K), dtype=self.cpx_dtype) - - ksp = np.zeros((B * C, K), dtype=self.cpx_dtype) + if ksp is None: + ksp = np.zeros((B * C, K), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, K)) # TODO: Add concurrency compute batch n while copying batch n+1 to device # and batch n-1 to host dataf = data.flatten() @@ -504,13 +506,13 @@ def _adj_op_sense_host(self, coeffs, img_d=None): # Define short name T, B, C = self.n_trans, self.n_batchs, self.n_coils K, XYZ = self.n_samples, self.shape + + coeffs_f = coeffs.flatten() # Allocate memory coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) if img_d is None: img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) - smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - coeffs_f = coeffs.flatten() ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) if self.uses_density: density_batched = cp.repeat(self.density[None, :], T, axis=0) @@ -533,23 +535,16 @@ def _adj_op_sense_host(self, coeffs, img_d=None): return img def _adj_op_calibless_device(self, coeffs, img_d=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape coeffs_f = coeffs.flatten() - n_trans_samples = self.n_trans * self.n_samples - ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype) + ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype) if self.uses_density: - density_batched = cp.repeat( - self.density[None, :], self.n_trans, axis=0 - ).flatten() - img_d = img_d or cp.empty( - (self.n_batchs, self.n_coils, *self.shape), - dtype=self.cpx_dtype, - ) - for i in range((self.n_coils * self.n_batchs) // self.n_trans): + density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten() + img_d = img_d or cp.empty((B, C, *XYZ), dtype=self.cpx_dtype) + for i in range((B * C) // T): if self.uses_density: - cp.copyto( - ksp_batched, - coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples], - ) + cp.copyto(ksp_batched, coeffs_f[i * T * K : (i + 1) * T * K]) ksp_batched *= density_batched self.__adj_op(get_ptr(ksp_batched), get_ptr(img_d) + i * self.bsize_img) else: @@ -560,28 +555,25 @@ def _adj_op_calibless_device(self, coeffs, img_d=None): return img_d def _adj_op_calibless_host(self, coeffs, img_batched=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = 
self.n_samples, self.shape coeffs_f = coeffs.flatten() - n_trans_samples = self.n_trans * self.n_samples - ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype) + ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype) if self.uses_density: - density_batched = cp.repeat( - self.density[None, :], self.n_trans, axis=0 - ).flatten() + density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten() - img = np.zeros( - (self.n_batchs * self.n_coils, *self.shape), dtype=self.cpx_dtype - ) + img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype) if img_batched is None: - img_batched = cp.empty((self.n_trans, *self.shape), dtype=self.cpx_dtype) + img_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) # TODO: Add concurrency compute batch n while copying batch n+1 to device # and batch n-1 to host - for i in range((self.n_batchs * self.n_coils) // self.n_trans): - ksp_batched.set(coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples]) + for i in range((B * C) // T): + ksp_batched.set(coeffs_f[i * T * K : (i + 1) * T * K]) if self.uses_density: ksp_batched *= density_batched self.__adj_op(get_ptr(ksp_batched), get_ptr(img_batched)) - img[i * self.n_trans : (i + 1) * self.n_trans] = img_batched.get() - img = img.reshape((self.n_batchs, self.n_coils, *self.shape)) + img[i * T : (i + 1) * T] = img_batched.get() + img = img.reshape((B, C, *XYZ)) return img @nvtx_mark() diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 7bd498b2..f2c272a9 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -1,10 +1,24 @@ """Stacked Operator for NUFFT.""" +import warnings import numpy as np import scipy as sp -from .base import FourierOperatorBase, proper_trajectory -from . import get_operator +from .base import FourierOperatorBase, check_backend, get_operator, proper_trajectory +from .interfaces.utils import ( + is_cuda_array, + is_host_array, + pin_memory, + sizeof_fmt, +) + +CUPY_AVAILABLE = True +try: + import cupy as cp + from cupyx.scipy import fft as cpfft + +except ImportError: + CUPY_AVAILABLE = False class MRIStackedNUFFT(FourierOperatorBase): @@ -36,7 +50,7 @@ class MRIStackedNUFFT(FourierOperatorBase): # Internally the stacked NUFFT operator (self) uses a backend MRI aware NUFFT # operator(op), configured as such: # - op.smaps=None - # - op.n_coils = len(self.z_index) ; op.n_batchs = self.n_coils * self.n_batchs. + # - op.n_coils = self.n_coils * len(self.z_index) ; op.n_batchs = 1.
# The kspace is organized as a 2D array of shape # (self.n_batchs, self.n_coils, self.n_samples) Note that the stack dimension is # fused with the samples @@ -57,8 +71,25 @@ def __init__( **kwargs, ): super().__init__() - + samples2d, z_index_ = self._init_samples(samples, z_index, shape) self.shape = shape + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.n_coils = n_coils + self.n_batchs = n_batchs + self.squeeze_dims = squeeze_dims + self.smaps = smaps + self.operator = get_operator(backend)( + self.samples, + shape[:-1], + n_coils=self.n_coils * len(self.z_index), + smaps=None, + squeeze_dims=True, + **kwargs, + ) + + @staticmethod + def _init_samples(samples, z_index, shape): samples_dim = samples.shape[-1] auto_z = isinstance(z_index, str) and z_index == "auto" if samples_dim == len(shape) and auto_z: @@ -80,21 +111,7 @@ def __init__( ) from e else: raise ValueError("Invalid samples or z-index") - - self.samples = samples2d.reshape(-1, 2) - self.z_index = z_index_ - self.n_coils = n_coils - self.n_batchs = n_batchs - self.squeeze_dims = squeeze_dims - self.smaps = smaps - self.operator = get_operator(backend)( - self.samples, - shape[:-1], - n_coils=self.n_coils, - smaps=None, - squeeze_dims=squeeze_dims, - **kwargs, - ) + return samples2d, z_index_ @property def dtype(self): @@ -130,82 +147,82 @@ def op(self, data, ksp=None): def _op_sense(self, data, ksp=None): """Apply SENSE operator.""" - ksp = ksp or np.zeros( - ( - self.n_batchs, - self.n_coils, - len(self.z_index), - len(self.samples), - ), - dtype=self.cpx_dtype, - ) - data_ = data.reshape(self.n_batchs, *self.shape) - for b in range(self.n_batchs): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + if ksp is None: + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B, C * NZ, NS)) + data_ = data.reshape(B, *XYZ) + for b in range(B): data_c = data_[b] * self.smaps - ksp_z = self._fftz(data_c) - ksp_z = ksp_z.reshape(self.n_coils, *self.shape) - for i, zidx in enumerate(self.z_index): - # TODO Both array slices yields non continuous views. - t = np.ascontiguousarray(ksp_z[..., zidx]) - ksp[b, ..., i, :] = self.operator.op(t) - ksp = ksp.reshape(self.n_batchs, self.n_coils, self.n_samples) + data_c = self._fftz(data_c) + data_c = data_c.reshape(C, *XYZ) + tmp = np.ascontiguousarray(data_c[..., self.z_index]) + tmp = np.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(C * NZ, *XYZ[:2]) + ksp[b, ...] 
= self.operator.op(np.ascontiguousarray(tmp)) + ksp = ksp.reshape((B, C, NZ * NS)) return ksp def _op_calibless(self, data, ksp=None): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) if ksp is None: - ksp = np.empty( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - dtype=self.cpx_dtype, - ) - data_ = data.reshape((self.n_batchs, self.n_coils, *self.shape)) + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B, C * NZ, NS)) + data_ = data.reshape(B, C, *XYZ) ksp_z = self._fftz(data_) - ksp_z = ksp_z.reshape((self.n_batchs, self.n_coils, *self.shape)) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - t = np.ascontiguousarray(ksp_z[b, ..., zidx]) - ksp[b, ..., i, :] = self.operator.op(t) - ksp = ksp.reshape(self.n_batchs, self.n_coils, self.n_samples) + ksp_z = ksp_z.reshape((B, C, *XYZ)) + for b in range(B): + tmp = ksp_z[b][..., self.z_index] + tmp = np.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(C * NZ, *XYZ[:2]) + ksp[b, ...] = self.operator.op(np.ascontiguousarray(tmp)) + ksp = ksp.reshape((B, C, NZ, NS)) + ksp = ksp.reshape((B, C, NZ * NS)) return ksp def adj_op(self, coeffs, img=None): """Adjoint operator.""" - coeffs_ = np.reshape( - coeffs, (self.n_batchs, self.n_coils, len(self.samples), len(self.z_index)) - ) if self.uses_sense: - return self._safe_squeeze(self._adj_op_sense(coeffs_, img)) - return self._safe_squeeze(self._adj_op_calibless(coeffs_, img)) + return self._safe_squeeze(self._adj_op_sense(coeffs, img)) + return self._safe_squeeze(self._adj_op_calibless(coeffs, img)) def _adj_op_sense(self, coeffs, img): - imgz = np.zeros( - (self.n_batchs, self.n_coils, *self.shape), dtype=self.cpx_dtype - ) - coeffs_ = coeffs.reshape( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - ) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - # TODO Both array slices yields non continuous views. 
- t = np.ascontiguousarray(coeffs_[b, ..., i, :]) - imgz[b, ..., zidx] = self.operator.adj_op(t) + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + tmp = np.ascontiguousarray(coeffs_[b, ...]) + tmp_adj = self.operator.adj_op(tmp) + # move the z axis back + tmp_adj = tmp_adj.reshape(C, NZ, *XYZ[:2]) + tmp_adj = np.moveaxis(tmp_adj, 1, -1) + imgz[b][..., self.z_index] = tmp_adj imgc = self._ifftz(imgz) - img = img or np.empty((self.n_batchs, *self.shape), dtype=self.cpx_dtype) - for b in range(self.n_batchs): + img = img or np.empty((B, *XYZ), dtype=self.cpx_dtype) + for b in range(B): img[b] = np.sum(imgc[b] * self.smaps.conj(), axis=0) return img def _adj_op_calibless(self, coeffs, img): - imgz = np.zeros( - (self.n_batchs, self.n_coils, *self.shape), dtype=self.cpx_dtype - ) - coeffs_ = coeffs.reshape( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - ) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - t = np.ascontiguousarray(coeffs_[b, ..., i, :]) - imgz[b, ..., zidx] = self.operator.adj_op(t) - imgz = np.reshape(imgz, (self.n_batchs, self.n_coils, *self.shape)) + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C, NZ, NS)) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + t = np.ascontiguousarray(coeffs_[b, ...]) + adj = self.operator.adj_op(t) + # move the z axis back + adj = adj.reshape(C, NZ, *XYZ[:2]) + adj = np.moveaxis(adj, 1, -1) + imgz[b][..., self.z_index] = np.ascontiguousarray(adj) + imgz = np.reshape(imgz, (B, C, *XYZ)) img = self._ifftz(imgz) return img @@ -223,6 +240,298 @@ def _safe_squeeze(self, arr): return arr +class MRIStackedNUFFTGPU(MRIStackedNUFFT): + """ + Stacked NUFFT Operator for MRI using GPU only backend. + + This requires cufinufft to be installed. + + Parameters + ---------- + samples : array-like + Sample locations in a 2D kspace + shape: tuple + Shape of the image. + z_index: array-like + Cartesian z index of masked plane. + smaps: array-like + Sensitivity maps. + n_coils: int + Number of coils. + n_batchs: int + Number of batchs. + **kwargs: dict + Additional arguments to pass to the backend. + """ + + backend = "stacked-cufinufft" + available = True # the true availability will be checked at runtime.
+ + def __init__( + self, + samples, + shape, + z_index, + smaps, + n_coils=1, + n_batchs=1, + n_trans=1, + squeeze_dims=False, + smaps_cached=False, + density=False, + **kwargs, + ): + if not (CUPY_AVAILABLE and check_backend("cufinufft")): + raise RuntimeError("Cupy and cufinufft are required for this backend.") + + if (n_batchs * n_coils) % n_trans != 0: + raise ValueError("n_batchs * n_coils should be a multiple of n_trans") + + samples2d, z_index_ = self._init_samples(samples, z_index, shape) + self.shape = shape + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.n_coils = n_coils + self.n_batchs = n_batchs + self.n_trans = n_trans + self.squeeze_dims = squeeze_dims + + self.operator = get_operator("cufinufft")( + self.samples, + shape[:-1], + n_coils=n_trans * len(self.z_index), + n_trans=len(self.z_index), + smaps=None, + squeeze_dims=True, + density=density, + **kwargs, + ) + # Smaps support + self.smaps = smaps + self.smaps_cached = False + if smaps is not None: + if not (is_host_array(smaps) or is_cuda_array(smaps)): + raise ValueError( + "Smaps should be either a C-ordered ndarray, " "or a GPUArray." + ) + self.smaps_cached = False + if smaps_cached: + warnings.warn( + f"{sizeof_fmt(smaps.size * np.dtype(self.cpx_dtype).itemsize)}" + " used on gpu for smaps." + ) + self.smaps = cp.array( + smaps, order="C", copy=False, dtype=self.cpx_dtype + ) + self.smaps_cached = True + else: + self.smaps = pin_memory(smaps.astype(self.cpx_dtype)) + self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype) + + @property + def norm_factor(self): + """Norm factor of the operator.""" + return self.operator.norm_factor * np.sqrt(2) + + @staticmethod + def _fftz(data): + """Apply FFT on z-axis.""" + # sqrt(2) required for normalization + return cpfft.fftshift( + cpfft.fft( + cpfft.ifftshift(data, axes=-1), + axis=-1, + norm="ortho", + overwrite_x=True, + ), + axes=-1, + ) + + @staticmethod + def _ifftz(data): + """Apply IFFT on z-axis.""" + # sqrt(2) required for normalization + return cpfft.fftshift( + cpfft.ifft( + cpfft.ifftshift(data, axes=-1), + axis=-1, + norm="ortho", + overwrite_x=False, + ), + axes=-1, + ) + + def op(self, data, ksp=None): + """Forward operator.""" + # Dispatch to special case.
+ if self.uses_sense and is_cuda_array(data): + op_func = self._op_sense_device + elif self.uses_sense: + op_func = self._op_sense_host + elif is_cuda_array(data): + op_func = self._op_calibless_device + else: + op_func = self._op_calibless_host + ret = op_func(data, ksp) + + return self._safe_squeeze(ret) + + def _op_sense_host(self, data, ksp=None): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + dataf = data.reshape((B, *XYZ)) + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + + if ksp is None: + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, NZ * NS)) + ksp_batched = cp.empty((T * NZ, NS), dtype=self.cpx_dtype) + for i in range((B * C) // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + # Send the n_trans coils to gpu + data_batched.set(dataf[idx_batch].reshape((T, *XYZ))) + # Apply Smaps + if not self.smaps_cached: + coil_img_d.set(self.smaps[idx_coils].reshape((T, *XYZ))) + else: + cp.copyto(coil_img_d, self.smaps[idx_coils]) + coil_img_d *= data_batched + # FFT along Z axis (last) + coil_img_d = self._fftz(coil_img_d) + coil_img_d = coil_img_d.reshape((T, *XYZ)) + tmp = coil_img_d[..., self.z_index] + tmp = cp.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(T * NZ, *XYZ[:2]) + # After reordering, apply 2D NUFFT + ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) + ksp_batched /= self.norm_factor + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T, NZ * NS) + ksp[i * T : (i + 1) * T] = ksp_batched.get() + ksp = ksp.reshape((B, C, NZ * NS)) + return ksp + + def _op_sense_device(self, data, ksp): + raise NotImplementedError + + def _op_calibless_host(self, data, ksp=None): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, NZ * NS), dtype=self.dtype) + if ksp is None: + ksp = np.zeros((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, NZ * NS)) + + dataf = data.reshape(B * C, *XYZ) + + for i in range((B * C) // T): + coil_img_d.set(dataf[i * T : (i + 1) * T]) + coil_img_d = self._fftz(coil_img_d) + coil_img_d = coil_img_d.reshape((T, *XYZ)) + tmp = coil_img_d[..., self.z_index] + tmp = cp.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(T * NZ, *XYZ[:2]) + # After reordering, apply 2D NUFFT + ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) + ksp_batched /= self.norm_factor + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T, NZ * NS) + ksp[i * T : (i + 1) * T] = ksp_batched.get() + + ksp = ksp.reshape((B, C, NZ * NS)) + return ksp + + def _op_calibless_device(self, data, ksp=None): + raise NotImplementedError + + def adj_op(self, coeffs, img=None): + """Adjoint operator.""" + # Dispatch to special case. 
+ if self.uses_sense and is_cuda_array(coeffs): + adj_op_func = self._adj_op_sense_device + elif self.uses_sense: + adj_op_func = self._adj_op_sense_host + elif is_cuda_array(coeffs): + adj_op_func = self._adj_op_calibless_device + else: + adj_op_func = self._adj_op_calibless_host + + ret = adj_op_func(coeffs, img) + + return self._safe_squeeze(ret) + + def _adj_op_sense_host(self, coeffs, img_d=None): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + coeffs_f = coeffs.reshape(B * C, NZ * NS) + # Allocate Memory + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + if img_d is None: + img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) + smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, NS * NZ), dtype=self.cpx_dtype) + + for i in range((B * C) // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + if not self.smaps_cached: + smaps_batched.set(self.smaps[idx_coils]) + else: + smaps_batched = self.smaps[idx_coils] + ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) + + tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) + tmp_adj /= self.norm_factor + tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) + tmp_adj = cp.moveaxis(tmp_adj, 1, -1) + coil_img_d[:] = 0j + coil_img_d[..., self.z_index] = tmp_adj + coil_img_d = self._ifftz(coil_img_d) + + for t, b in enumerate(idx_batch): + img_d[b, :] += coil_img_d[t] * smaps_batched[t].conj() + img = img_d.get() + img = img.reshape((B, 1, *XYZ)) + return img + + def _adj_op_sense_device(self, coeffs, img): + raise NotImplementedError + + def _adj_op_calibless_host(self, coeffs, img=None): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + coeffs_f = coeffs.reshape(B, C, NZ * NS) + coeffs_f = coeffs_f.reshape(B * C, NZ * NS) + # Allocate Memory + ksp_batched = cp.empty((T, NZ * NS), dtype=self.cpx_dtype) + img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype) + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + for i in range((B * C) // T): + ksp_batched = ksp_batched.reshape(T, NZ * NS) + ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T * NZ, NS) + tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) + tmp_adj /= self.norm_factor + tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) + tmp_adj = cp.moveaxis(tmp_adj, 1, -1) + coil_img_d[:] = 0j + coil_img_d[..., self.z_index] = tmp_adj + coil_img_d = self._ifftz(coil_img_d) + img[i * T : (i + 1) * T, ...] = coil_img_d.get() + img = img.reshape(B, C, *XYZ) + return img + + def _adj_op_calibless_device(self, coeffs, img): + raise NotImplementedError + + def traj3d2stacked(samples, dim_z, n_samples=0): """Convert a 3D trajectory into a trajectory and the z-stack index. diff --git a/tests/test_stacked_gpu.py b/tests/test_stacked_gpu.py new file mode 100644 index 00000000..6d57c437 --- /dev/null +++ b/tests/test_stacked_gpu.py @@ -0,0 +1,149 @@ +"""Test for the stacked NUFFT operator. + +The tests compare the stacked NUFFT (which uses FFT in the z-direction) +and the fully 3D one.
+""" +import numpy as np + +import numpy.testing as npt +from pytest_cases import parametrize_with_cases, parametrize, fixture + +from mrinufft.operators.stacked import ( + MRIStackedNUFFTGPU, + stacked2traj3d, + traj3d2stacked, +) +from mrinufft import get_operator +from case_trajectories import CasesTrajectories + + +@fixture(scope="module") +@parametrize( + "n_batchs, n_coils, sense", + [(1, 1, False), (1, 4, False), (1, 4, True), (3, 4, False), (3, 4, True)], +) +@parametrize("z_index", ["full", "random_mask"]) +@parametrize("backend", ["cufinufft"]) +@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories, glob="*2D") +def operator(request, backend, kspace_locs, shape, z_index, n_batchs, n_coils, sense): + """Initialize the stacked and non-stacked operators.""" + shape3d = (*shape, shape[-1] - 2) # add a 3rd dimension + + if z_index == "full": + z_index = np.arange(shape3d[-1]) + z_index_ = z_index + elif z_index == "random_mask": + z_index = np.random.rand(shape3d[-1]) > 0.5 + z_index_ = np.arange(shape3d[-1])[z_index] + + kspace_locs3d = stacked2traj3d(kspace_locs, z_index_, shape3d[-1]) + # smaps support + if sense: + smaps = 1j * np.random.rand(n_coils, *shape3d) + smaps += np.random.rand(n_coils, *shape3d) + else: + smaps = None + + # Setup the operators + ref = get_operator(backend)( + kspace_locs3d, + shape=shape3d, + n_coils=n_coils, + n_batchs=n_batchs, + smaps=smaps, + ) + + stacked = MRIStackedNUFFTGPU( + samples=kspace_locs, + shape=shape3d, + z_index=z_index, + n_coils=n_coils, + n_trans=2 if n_coils > 1 else 1, + n_batchs=n_batchs, + smaps=smaps, + ) + return stacked, ref + + +@fixture(scope="module") +def stacked_op(operator): + """Return operator.""" + return operator[0] + + +@fixture(scope="module") +def ref_op(operator): + """Return ref operator.""" + return operator[1] + + +@fixture(scope="module") +def image_data(stacked_op): + """Generate a random image.""" + B, C = stacked_op.n_batchs, stacked_op.n_coils + if stacked_op.smaps is None: + img = np.random.randn(B, C, *stacked_op.shape).astype(stacked_op.cpx_dtype) + elif stacked_op.smaps is not None and stacked_op.n_coils > 1: + img = np.random.randn(B, *stacked_op.shape).astype(stacked_op.cpx_dtype) + + img += 1j * np.random.randn(*img.shape).astype(stacked_op.cpx_dtype) + return img + + +@fixture(scope="module") +def kspace_data(stacked_op): + """Generate a random kspace data.""" + B, C = stacked_op.n_batchs, stacked_op.n_coils + kspace = (1j * np.random.randn(B, C, stacked_op.n_samples)).astype( + stacked_op.cpx_dtype + ) + kspace += np.random.randn(B, C, stacked_op.n_samples).astype(stacked_op.cpx_dtype) + return kspace + + +def test_stack_forward(operator, stacked_op, ref_op, image_data): + """Compare the stack interface to the 3D NUFFT implementation.""" + kspace_nufft = stacked_op.op(image_data).squeeze() + kspace_ref = ref_op.op(image_data).squeeze() + npt.assert_allclose(kspace_nufft, kspace_ref, atol=1e-4, rtol=1e-1) + + +def test_stack_backward(operator, stacked_op, ref_op, kspace_data): + """Compare the stack interface to the 3D NUFFT implementation.""" + image_nufft = stacked_op.adj_op(kspace_data.copy()).squeeze() + image_ref = ref_op.adj_op(kspace_data.copy()).squeeze() + if stacked_op.n_coils > 1: + print( + np.max( + np.abs(image_nufft - image_ref).reshape( + stacked_op.n_batchs, stacked_op.n_coils, -1 + ), + axis=-1, + ) + ) + print(image_nufft.shape, image_ref.shape) + npt.assert_allclose(image_nufft, image_ref, atol=1e-4, rtol=1e-1) + + +def test_stack_auto_adjoint(operator, 
stacked_op, kspace_data, image_data): + """Test the adjoint property of the stacked NUFFT operator.""" + kspace = stacked_op.op(image_data) + image = stacked_op.adj_op(kspace_data) + leftadjoint = np.vdot(image, image_data) + rightadjoint = np.vdot(kspace, kspace_data) + + npt.assert_allclose(leftadjoint.conj(), rightadjoint, atol=1e-4, rtol=1e-4) + + +def test_stacked2traj3d(): + """Test the conversion from stacked to 3d trajectory.""" + dimz = 64 + traj2d = np.random.randn(100, 2) + z_index = np.random.choice(dimz, 20, replace=False) + + traj3d = stacked2traj3d(traj2d, z_index, dimz) + + traj2d_2, z_index2 = traj3d2stacked(traj3d, dimz) + + npt.assert_allclose(traj2d, traj2d_2) + npt.assert_allclose(z_index, z_index2)
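
For reference, the comment in stacked.py describes how the stacked operators organize k-space: the sampled z-planes are presented to the inner 2D NUFFT as extra coils, and the stack dimension ends up fused with the samples in the output. A standalone NumPy sketch of that axis bookkeeping follows; it is not part of the diff, `inner_op` is a stand-in for the wrapped 2D operator, and all shapes are illustrative only.

import numpy as np

# Illustrative sizes: batchs, coils, sampled z-planes, samples per plane, in-plane size.
B, C, NZ, NS = 2, 4, 8, 100
X, Y = 16, 16
z_index = np.arange(NZ)  # hypothetical Cartesian z mask (here: fully sampled)

# k-space volume after the FFT along z, shape (B, C, X, Y, NZ)
ksp_z = np.zeros((B, C, X, Y, NZ), dtype=np.complex64)
for b in range(B):
    tmp = ksp_z[b][..., z_index]      # (C, X, Y, NZ): keep only the sampled planes
    tmp = np.moveaxis(tmp, -1, 1)     # (C, NZ, X, Y): one 2D plane per inner "coil"
    tmp = tmp.reshape(C * NZ, X, Y)   # what the inner 2D NUFFT treats as its coil axis
    # inner_op.op(tmp) would return an array of shape (C * NZ, NS)

# The public k-space then fuses the stack dimension with the samples:
ksp = np.zeros((B, C, NZ, NS), dtype=np.complex64)
ksp = ksp.reshape(B, C, NZ * NS)      # (n_batchs, n_coils, n_samples)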
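
A minimal, hypothetical usage sketch of the new GPU stacked operator, mirroring the way the new test instantiates it. It assumes a CUDA-capable machine with cupy and cufinufft installed; the trajectory, sizes, and parameter values are illustrative, not prescriptive.

import numpy as np
from mrinufft.operators.stacked import MRIStackedNUFFTGPU

shape3d = (64, 64, 32)                                # (X, Y, Z) image shape
samples2d = np.random.uniform(-0.5, 0.5, (1000, 2)).astype(np.float32)  # in-plane trajectory
z_index = np.arange(shape3d[-1])                      # fully sampled stack along z

op = MRIStackedNUFFTGPU(
    samples=samples2d,
    shape=shape3d,
    z_index=z_index,
    smaps=None,                                       # calibrationless (no SENSE)
    n_coils=4,
    n_batchs=1,
    n_trans=2,                                        # must divide n_batchs * n_coils
)

img = np.random.randn(1, 4, *shape3d).astype(np.complex64)
ksp = op.op(img)                                      # (n_batchs, n_coils, n_samples)
adj = op.adj_op(ksp)                                  # (n_batchs, n_coils, *shape3d)

Since n_trans coil images are processed per inner NUFFT call, n_batchs * n_coils must be divisible by n_trans, which is exactly what the constructor check in the diff enforces.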