From d95250f8d8082f4037eda34265b614e54a594b45 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 13 Oct 2023 15:59:57 +0200 Subject: [PATCH 01/15] feat: use a very large coil dimension. --- src/mrinufft/operators/stacked.py | 120 +++++++++++++++--------------- 1 file changed, 60 insertions(+), 60 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 7bd498b2..bfa17b64 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -36,7 +36,7 @@ class MRIStackedNUFFT(FourierOperatorBase): # Internally the stacked NUFFT operator (self) uses a backend MRI aware NUFFT # operator(op), configured as such: # - op.smaps=None - # - op.n_coils = len(self.z_index) ; op.n_batchs = self.n_coils * self.n_batchs. + # - op.n_coils self.n_coils * len(self.z_index) ; op.n_batch= 1. # The kspace is organized as a 2D array of shape # (self.n_batchs, self.n_coils, self.n_samples) Note that the stack dimension is # fused with the samples @@ -90,9 +90,9 @@ def __init__( self.operator = get_operator(backend)( self.samples, shape[:-1], - n_coils=self.n_coils, + n_coils=self.n_coils * len(self.z_index), smaps=None, - squeeze_dims=squeeze_dims, + squeeze_dims=True, **kwargs, ) @@ -130,82 +130,82 @@ def op(self, data, ksp=None): def _op_sense(self, data, ksp=None): """Apply SENSE operator.""" - ksp = ksp or np.zeros( - ( - self.n_batchs, - self.n_coils, - len(self.z_index), - len(self.samples), - ), - dtype=self.cpx_dtype, - ) - data_ = data.reshape(self.n_batchs, *self.shape) - for b in range(self.n_batchs): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + if ksp is None: + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B, C * NZ, NS)) + data_ = data.reshape(B, *XYZ) + for b in range(B): data_c = data_[b] * self.smaps ksp_z = self._fftz(data_c) - ksp_z = ksp_z.reshape(self.n_coils, *self.shape) - for i, zidx in enumerate(self.z_index): - # TODO Both array slices yields non continuous views. - t = np.ascontiguousarray(ksp_z[..., zidx]) - ksp[b, ..., i, :] = self.operator.op(t) - ksp = ksp.reshape(self.n_batchs, self.n_coils, self.n_samples) + ksp_z = ksp_z.reshape(C, *XYZ) + tmp = np.ascontiguousarray(ksp_z[..., self.z_index]) + tmp = np.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(C * NZ, *XYZ[:2]) + ksp[b, ...] = self.operator.op(np.ascontiguousarray(tmp)) + ksp = ksp.reshape((B, C, NZ * NS)) return ksp def _op_calibless(self, data, ksp=None): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) if ksp is None: - ksp = np.empty( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - dtype=self.cpx_dtype, - ) - data_ = data.reshape((self.n_batchs, self.n_coils, *self.shape)) + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B, C * NZ, NS)) + data_ = data.reshape(B, C, *XYZ) ksp_z = self._fftz(data_) - ksp_z = ksp_z.reshape((self.n_batchs, self.n_coils, *self.shape)) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - t = np.ascontiguousarray(ksp_z[b, ..., zidx]) - ksp[b, ..., i, :] = self.operator.op(t) - ksp = ksp.reshape(self.n_batchs, self.n_coils, self.n_samples) + ksp_z = ksp_z.reshape((B, C, *XYZ)) + for b in range(B): + tmp = ksp_z[b][..., self.z_index] + tmp = np.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(C * NZ, *XYZ[:2]) + ksp[b, ...] = self.operator.op(np.ascontiguousarray(tmp)) + ksp = ksp.reshape((B, C, NZ, NS)) + ksp = ksp.reshape((B, C, NZ * NS)) return ksp def adj_op(self, coeffs, img=None): """Adjoint operator.""" - coeffs_ = np.reshape( - coeffs, (self.n_batchs, self.n_coils, len(self.samples), len(self.z_index)) - ) if self.uses_sense: - return self._safe_squeeze(self._adj_op_sense(coeffs_, img)) - return self._safe_squeeze(self._adj_op_calibless(coeffs_, img)) + return self._safe_squeeze(self._adj_op_sense(coeffs, img)) + return self._safe_squeeze(self._adj_op_calibless(coeffs, img)) def _adj_op_sense(self, coeffs, img): - imgz = np.zeros( - (self.n_batchs, self.n_coils, *self.shape), dtype=self.cpx_dtype - ) - coeffs_ = coeffs.reshape( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - ) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - # TODO Both array slices yields non continuous views. - t = np.ascontiguousarray(coeffs_[b, ..., i, :]) - imgz[b, ..., zidx] = self.operator.adj_op(t) + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + tmp = np.ascontiguousarray(coeffs_[b, ...]) + tmp_adj = self.operator.adj_op(tmp) + # move the z axis back + tmp_adj = tmp_adj.reshape(C, NZ, *XYZ[:2]) + tmp_adj = np.moveaxis(tmp_adj, 1, -1) + imgz[b][..., self.z_index] = tmp_adj imgc = self._ifftz(imgz) - img = img or np.empty((self.n_batchs, *self.shape), dtype=self.cpx_dtype) - for b in range(self.n_batchs): + img = img or np.empty((B, *XYZ), dtype=self.cpx_dtype) + for b in range(B): img[b] = np.sum(imgc[b] * self.smaps.conj(), axis=0) return img def _adj_op_calibless(self, coeffs, img): - imgz = np.zeros( - (self.n_batchs, self.n_coils, *self.shape), dtype=self.cpx_dtype - ) - coeffs_ = coeffs.reshape( - (self.n_batchs, self.n_coils, len(self.z_index), len(self.samples)), - ) - for b in range(self.n_batchs): - for i, zidx in enumerate(self.z_index): - t = np.ascontiguousarray(coeffs_[b, ..., i, :]) - imgz[b, ..., zidx] = self.operator.adj_op(t) - imgz = np.reshape(imgz, (self.n_batchs, self.n_coils, *self.shape)) + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C, NZ, NS)) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + t = np.ascontiguousarray(coeffs_[b, ...]) + adj = self.operator.adj_op(t) + # move the z axis back + adj = adj.reshape(C, NZ, *XYZ[:2]) + adj = np.moveaxis(adj, 1, -1) + imgz[b][..., self.z_index] = np.ascontiguousarray(adj) + imgz = np.reshape(imgz, (B, C, *XYZ)) img = self._ifftz(imgz) return img From 6d2504d0530131ff2ab240a5616bedc91a59b346 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:09:19 +0200 Subject: [PATCH 02/15] refactor(cufinufft): use same conventions across _adj_op functions. --- .../operators/interfaces/cufinufft.py | 56 ++++++++----------- 1 file changed, 24 insertions(+), 32 deletions(-) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index eac477d1..690054ab 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -368,7 +368,8 @@ def _op_sense_host(self, data, ksp=None): coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) dataf = data.reshape((B, *XYZ)) data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - ksp = ksp or np.empty((B, C, K), dtype=self.cpx_dtype) + if ksp is None: + ksp = np.empty((B, C, K), dtype=self.cpx_dtype) ksp = ksp.reshape((B * C, K)) ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) @@ -408,8 +409,9 @@ def _op_calibless_host(self, data, ksp=None): coil_img_d = cp.empty(np.prod(XYZ) * T, dtype=self.cpx_dtype) ksp_d = cp.empty((T, K), dtype=self.cpx_dtype) - - ksp = np.zeros((B * C, K), dtype=self.cpx_dtype) + if ksp is None: + ksp = np.zeros((B * C, K), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, K)) # TODO: Add concurrency compute batch n while copying batch n+1 to device # and batch n-1 to host dataf = data.flatten() @@ -504,13 +506,13 @@ def _adj_op_sense_host(self, coeffs, img_d=None): # Define short name T, B, C = self.n_trans, self.n_batchs, self.n_coils K, XYZ = self.n_samples, self.shape + + coeffs_f = coeffs.flatten() # Allocate memory coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) if img_d is None: img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) - smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - coeffs_f = coeffs.flatten() ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) if self.uses_density: density_batched = cp.repeat(self.density[None, :], T, axis=0) @@ -533,23 +535,16 @@ def _adj_op_sense_host(self, coeffs, img_d=None): return img def _adj_op_calibless_device(self, coeffs, img_d=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape coeffs_f = coeffs.flatten() - n_trans_samples = self.n_trans * self.n_samples - ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype) + ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype) if self.uses_density: - density_batched = cp.repeat( - self.density[None, :], self.n_trans, axis=0 - ).flatten() - img_d = img_d or cp.empty( - (self.n_batchs, self.n_coils, *self.shape), - dtype=self.cpx_dtype, - ) - for i in range((self.n_coils * self.n_batchs) // self.n_trans): + density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten() + img_d = img_d or cp.empty((B, C, *XYZ), dtype=self.cpx_dtype) + for i in range((B * C) // T): if self.uses_density: - cp.copyto( - ksp_batched, - coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples], - ) + cp.copyto(ksp_batched, coeffs_f[i * T * K : (i + 1) * T * K]) ksp_batched *= density_batched self.__adj_op(get_ptr(ksp_batched), get_ptr(img_d) + i * self.bsize_img) else: @@ -560,28 +555,25 @@ def _adj_op_calibless_device(self, coeffs, img_d=None): return img_d def _adj_op_calibless_host(self, coeffs, img_batched=None): + T, B, C = self.n_trans, self.n_batchs, self.n_coils + K, XYZ = self.n_samples, self.shape coeffs_f = coeffs.flatten() - n_trans_samples = self.n_trans * self.n_samples - ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype) + ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype) if self.uses_density: - density_batched = cp.repeat( - self.density[None, :], self.n_trans, axis=0 - ).flatten() + density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten() - img = np.zeros( - (self.n_batchs * self.n_coils, *self.shape), dtype=self.cpx_dtype - ) + img = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) if img_batched is None: - img_batched = cp.empty((self.n_trans, *self.shape), dtype=self.cpx_dtype) + img_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) # TODO: Add concurrency compute batch n while copying batch n+1 to device # and batch n-1 to host - for i in range((self.n_batchs * self.n_coils) // self.n_trans): - ksp_batched.set(coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples]) + for i in range((B * C) // T): + ksp_batched.set(coeffs_f[i * T * K : (i + 1) * T * K]) if self.uses_density: ksp_batched *= density_batched self.__adj_op(get_ptr(ksp_batched), get_ptr(img_batched)) - img[i * self.n_trans : (i + 1) * self.n_trans] = img_batched.get() - img = img.reshape((self.n_batchs, self.n_coils, *self.shape)) + img[i * T : (i + 1) * T] = img_batched.get() + img = img.reshape((B, C, *XYZ)) return img @nvtx_mark() From 8816d226ca9ae8379d978891438f37d7fe824d7b Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:10:35 +0200 Subject: [PATCH 03/15] refactor(stacked): extract the samples preprocessing to dedicated method. --- src/mrinufft/operators/stacked.py | 35 +++++++++++++++++-------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index bfa17b64..3613ff56 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -57,8 +57,25 @@ def __init__( **kwargs, ): super().__init__() - + samples2d, z_index_ = self._init_samples(samples, z_index, shape) self.shape = shape + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.n_coils = n_coils + self.n_batchs = n_batchs + self.squeeze_dims = squeeze_dims + self.smaps = smaps + self.operator = get_operator(backend)( + self.samples, + shape[:-1], + n_coils=self.n_coils * len(self.z_index), + smaps=None, + squeeze_dims=True, + **kwargs, + ) + + @staticmethod + def _init_samples(samples, z_index, shape): samples_dim = samples.shape[-1] auto_z = isinstance(z_index, str) and z_index == "auto" if samples_dim == len(shape) and auto_z: @@ -80,21 +97,7 @@ def __init__( ) from e else: raise ValueError("Invalid samples or z-index") - - self.samples = samples2d.reshape(-1, 2) - self.z_index = z_index_ - self.n_coils = n_coils - self.n_batchs = n_batchs - self.squeeze_dims = squeeze_dims - self.smaps = smaps - self.operator = get_operator(backend)( - self.samples, - shape[:-1], - n_coils=self.n_coils * len(self.z_index), - smaps=None, - squeeze_dims=True, - **kwargs, - ) + return samples2d, z_index_ @property def dtype(self): From 1cdbdb729738d84bca4b272e5d77b711a8a5c038 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:11:39 +0200 Subject: [PATCH 04/15] feat(stacked): add GPU backend for stacked nufft. --- src/mrinufft/operators/stacked.py | 327 +++++++++++++++++++++++++++++- 1 file changed, 323 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 3613ff56..91babd20 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -1,11 +1,20 @@ """Stacked Operator for NUFFT.""" +import warnings import numpy as np import scipy as sp -from .base import FourierOperatorBase, proper_trajectory +from .base import FourierOperatorBase, proper_trajectory, check_backend from . import get_operator +CUPY_AVAILABLE = True +try: + import cupy as cp + from cupyx.scipy import fft as cpfft + from .interfaces.utils import is_cuda_array, is_host_array, pin_memory, sizeof_fmt +except ImportError: + CUPY_AVAILABLE = False + class MRIStackedNUFFT(FourierOperatorBase): """Stacked NUFFT Operator for MRI. @@ -142,9 +151,9 @@ def _op_sense(self, data, ksp=None): data_ = data.reshape(B, *XYZ) for b in range(B): data_c = data_[b] * self.smaps - ksp_z = self._fftz(data_c) - ksp_z = ksp_z.reshape(C, *XYZ) - tmp = np.ascontiguousarray(ksp_z[..., self.z_index]) + data_c = self._fftz(data_c) + data_c = data_c.reshape(C, *XYZ) + tmp = np.ascontiguousarray(data_c[..., self.z_index]) tmp = np.moveaxis(tmp, -1, 1) tmp = tmp.reshape(C * NZ, *XYZ[:2]) ksp[b, ...] = self.operator.op(np.ascontiguousarray(tmp)) @@ -226,6 +235,316 @@ def _safe_squeeze(self, arr): return arr +class MRIStackedNUFFTGPU(MRIStackedNUFFT): + """ + Stacked NUFFT Operator for MRI using GPU only backend. + + This requires cufinufft to be installed. + + Parameters + ---------- + samples : array-like + Sample locations in a 2D kspace + shape: tuple + Shape of the image. + z_index: array-like + Cartesian z index of masked plan. + smaps: array-like + Sensitivity maps. + n_coils: int + Number of coils. + n_batchs: int + Number of batchs. + **kwargs: dict + Additional arguments to pass to the backend. + """ + + backend = "stacked-cufinufft" + available = True # the true availabily will be check at runtime. + + def __init__( + self, + samples, + shape, + z_index, + smaps, + n_coils=1, + n_batchs=1, + n_trans=1, + squeeze_dims=False, + smaps_cached=False, + cufi_kwargs=None, + ): + if not (CUPY_AVAILABLE and check_backend("cufinufft")): + raise RuntimeError("Cupy and cufinufft are required for this backend.") + super().__init__() + + if (n_batchs * n_coils) % n_trans != 0: + raise ValueError("n_batchs * n_coils should be a multiple of n_transf") + + samples2d, z_index_ = self._init_samples(samples, z_index, shape) + self.shape = shape + self.samples = samples2d.reshape(-1, 2) + self.z_index = z_index_ + self.n_coils = n_coils + self.n_batchs = n_batchs + self.n_trans = n_trans + self.squeeze_dims = squeeze_dims + # Smaps support + self.smaps = smaps + self.smaps_cached = False + if smaps is not None: + if not (is_host_array(smaps) or is_cuda_array(smaps)): + raise ValueError( + "Smaps should be either a C-ordered ndarray, " "or a GPUArray." + ) + self.smaps_cached = False + if smaps_cached: + warnings.warn( + f"{sizeof_fmt(smaps.size * np.dtype(self.cpx_dtype).itemsize)}" + "used on gpu for smaps." + ) + self.smaps = cp.array( + smaps, order="C", copy=False, dtype=self.cpx_dtype + ) + self.smaps_cached = True + else: + self.smaps = pin_memory(smaps.astype(self.cpx_dtype)) + self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype) + + if cufi_kwargs is None: + cufi_kwargs = {} + self.operator = get_operator("cufinufft")( + self.samples, + shape[:-1], + n_coils=n_trans * len(self.z_index), + smaps=None, + squeeze_dims=True, + **cufi_kwargs, + ) + + @staticmethod + def _fftz(data): + """Apply FFT on z-axis.""" + # sqrt(2) required for normalization + return cpfft.fftshift( + cpfft.fft( + cpfft.ifftshift(data, axes=-1), axis=-1, norm="ortho", overwrite_x=True + ), + axes=-1, + ) / np.sqrt(2) + + @staticmethod + def _ifftz(data): + """Apply IFFT on z-axis.""" + # sqrt(2) required for normalization + return cpfft.fftshift( + cpfft.ifft( + cpfft.ifftshift(data, axes=-1), axis=-1, norm="ortho", overwrite_x=True + ), + axes=-1, + ) / np.sqrt(2) + + def op(self, data, ksp=None): + """Forward operator.""" + # Dispatch to special case. + if self.uses_sense and is_cuda_array(data): + op_func = self._op_sense_device + elif self.uses_sense: + op_func = self._op_sense_host + elif is_cuda_array(data): + op_func = self._op_calibless_device + else: + op_func = self._op_calibless_host + ret = op_func(data, ksp) + + return self._safe_squeeze(ret) + + def _op_sense_host(self, data, ksp): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + dataf = data.reshape((B, *XYZ)) + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + + if ksp is None: + ksp = np.empty((B, C, NZ * NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, NZ * NS)) + ksp_batched = cp.empty((T, NZ * NS), dtype=self.cpx_dtype) + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + # Send the n_trans coils to gpu + data_batched.set(dataf[idx_batch].reshape((T, *XYZ))) + # Apply Smaps + if not self.smaps_cached: + coil_img_d.set(self.smaps[idx_coils].reshape((T, *XYZ))) + else: + cp.copyto(coil_img_d, self.smaps[idx_coils]) + coil_img_d *= data_batched + # FFT along Z axis (last) + coil_img_d = self.fftz(coil_img_d) + coil_img_d = coil_img_d.reshape((T, *XYZ)) + tmp = coil_img_d[..., self.z_index] + tmp = cp.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(T * NZ, *XYZ[:2]) + # After reordering, apply 2D NUFFT + ksp_batched = self.operator.op(cp.ascontiguousarray(tmp)) + ksp[i * T : (i + 1) * T] = ksp_batched.get() + ksp = ksp.reshape((B, C, NZ * NS)) + ksp = ksp.reshape((B, C, NZ, NS)) + return ksp + + def _op_sense_device(self, data, ksp): + raise NotImplementedError + + def _op_calibless_host(self, data, ksp=None): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + coil_img_d = cp.empty(T, *XYZ, dtype=self.cpx_dtype) + ksp_batched = cp.empty(T, NZ * NS, dtype=self.dtype) + if ksp is None: + ksp = np.zeros((B, C, NZ * NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, NZ * NS)) + + dataf = data.reshape(B * C, *XYZ) + + for i in range((B * C) // T): + coil_img_d.set(dataf[i * T : (i + 1) * T]) + coil_img_d = self.fftz(coil_img_d) + coil_img_d = coil_img_d.reshape((T, *XYZ)) + tmp = coil_img_d[..., self.z_index] + tmp = cp.moveaxis(tmp, -1, 1) + tmp = tmp.reshape(T * NZ, *XYZ[:2]) + # After reordering, apply 2D NUFFT + ksp_batched = self.operator.op(cp.ascontiguousarray(tmp)) + ksp[i * T : (i + 1) * T] = ksp_batched.get() + + ksp = ksp.reshape((B, C, NZ * NS)) + ksp = ksp.reshape((B, C, NZ, NS)) + return ksp + + def _op_calibless_device(self, data, ksp=None): + raise NotImplementedError + + def adj_op(self, coeffs, img=None): + """Adjoint operator.""" + # Dispatch to special case. + if self.uses_sense and is_cuda_array(coeffs): + adj_op_func = self._adj_op_sense_device + elif self.uses_sense: + adj_op_func = self._adj_op_sense_host + elif is_cuda_array(coeffs): + adj_op_func = self._adj_op_calibless_device + else: + adj_op_func = self._adj_op_calibless_host + + ret = adj_op_func(coeffs, img) + + return self._safe_squeeze(ret) + + def _adj_op_sense_host(self, coeffs, img): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + coeffs_f = coeffs.reshape(B * C, NZ * NS) + # Allocate Memory + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + if img_d is None: + img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) + smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + + for i in range(B * C // T): + idx_coils = np.arange(i * T, (i + 1) * T) % C + idx_batch = np.arange(i * T, (i + 1) * T) // C + if not self.smaps_cached: + smaps_batched.set(self.smaps[idx_coils]) + else: + smaps_batched = self.smaps[idx_coils] + ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) + + tmp_adj = self.operator.adj_op(ksp_batched) + tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) + tmp_adj = cp.moveaxis(tmp_adj, 1, -1) + coil_img_d = cp.zeros_like(coil_img_d) + coil_img_d[..., self.z_index] = tmp_adj + coil_img_d = self.ifftz(coil_img_d) + + for t, b in enumerate(idx_batch): + img_d[b, :] += coil_img_d[t] * smaps_batched[t].conj() + img = img_d.get() + img = img.reshape((B, 1, *XYZ)) + return img + + def _adj_op_sense_device(self, coeffs, img): + raise NotImplementedError + + def _adj_op_calibless_host(self, coeffs, img): + B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + coeffs_f = coeffs.reshape(B * C, NZ * NS) + # Allocate Memory + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) + for i in range(B * C // T): + ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) + tmp_adj = self.operator.adj_op(ksp_batched) + tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) + tmp_adj = cp.moveaxis(tmp_adj, 1, -1) + + coil_img_d = cp.zeros_like(coil_img_d) + coil_img_d[..., self.z_index] = tmp_adj + coil_img_d = self.ifftz(coil_img_d) + img[i * T : (i + 1) * T] = coil_img_d.get() + img = img.reshape(B, C, *XYZ) + return img + + def _adj_op_calibless_device(self, coeffs, img): + raise NotImplementedError + + def _adj_op_sense(self, coeffs, img): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + tmp = np.ascontiguousarray(coeffs_[b, ...]) + tmp_adj = self.operator.adj_op(tmp) + # move the z axis back + tmp_adj = tmp_adj.reshape(C, NZ, *XYZ[:2]) + tmp_adj = np.moveaxis(tmp_adj, 1, -1) + imgz[b][..., self.z_index] = tmp_adj + imgc = self._ifftz(imgz) + img = img or np.empty((B, *XYZ), dtype=self.cpx_dtype) + for b in range(B): + img[b] = np.sum(imgc[b] * self.smaps.conj(), axis=0) + return img + + def _adj_op_calibless(self, coeffs, img): + B, C, XYZ = self.n_batchs, self.n_coils, self.shape + NS, NZ = len(self.samples), len(self.z_index) + + imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + coeffs_ = coeffs.reshape((B, C, NZ, NS)) + coeffs_ = coeffs.reshape((B, C * NZ, NS)) + for b in range(B): + t = np.ascontiguousarray(coeffs_[b, ...]) + adj = self.operator.adj_op(t) + # move the z axis back + adj = adj.reshape(C, NZ, *XYZ[:2]) + adj = np.moveaxis(adj, 1, -1) + imgz[b][..., self.z_index] = np.ascontiguousarray(adj) + imgz = np.reshape(imgz, (B, C, *XYZ)) + img = self._ifftz(imgz) + return img + + def traj3d2stacked(samples, dim_z, n_samples=0): """Convert a 3D trajectory into a trajectory and the z-stack index. From 0f96c5ff9a2ab0965327bd2439f1a40648b86c24 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:32:57 +0200 Subject: [PATCH 05/15] fix: stacked-cufinufft kwargs. --- src/mrinufft/operators/stacked.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 91babd20..4a2b71d2 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -273,11 +273,11 @@ def __init__( n_trans=1, squeeze_dims=False, smaps_cached=False, - cufi_kwargs=None, + density=False, + **kwargs, ): if not (CUPY_AVAILABLE and check_backend("cufinufft")): raise RuntimeError("Cupy and cufinufft are required for this backend.") - super().__init__() if (n_batchs * n_coils) % n_trans != 0: raise ValueError("n_batchs * n_coils should be a multiple of n_transf") @@ -318,9 +318,11 @@ def __init__( self.samples, shape[:-1], n_coils=n_trans * len(self.z_index), + n_trans=len(self.z_index), smaps=None, squeeze_dims=True, - **cufi_kwargs, + density=density, + **kwargs, ) @staticmethod From e58f0f5e95fd4895856afd5f8b72153954af61f8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:48:59 +0200 Subject: [PATCH 06/15] fix: init operator earlier. --- src/mrinufft/operators/stacked.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 4a2b71d2..145eef31 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -290,6 +290,18 @@ def __init__( self.n_batchs = n_batchs self.n_trans = n_trans self.squeeze_dims = squeeze_dims + + self.operator = get_operator("cufinufft")( + self.samples, + shape[:-1], + n_coils=n_trans * len(self.z_index), + n_trans=len(self.z_index), + smaps=None, + squeeze_dims=True, + density=density, + **kwargs, + ) + # Smaps support self.smaps = smaps self.smaps_cached = False @@ -312,19 +324,6 @@ def __init__( self.smaps = pin_memory(smaps.astype(self.cpx_dtype)) self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype) - if cufi_kwargs is None: - cufi_kwargs = {} - self.operator = get_operator("cufinufft")( - self.samples, - shape[:-1], - n_coils=n_trans * len(self.z_index), - n_trans=len(self.z_index), - smaps=None, - squeeze_dims=True, - density=density, - **kwargs, - ) - @staticmethod def _fftz(data): """Apply FFT on z-axis.""" From 510ce19d1f9e812c850a5cac2dc1a080d44aa1ad Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:49:18 +0200 Subject: [PATCH 07/15] fix: use internal device methods. --- src/mrinufft/operators/stacked.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 145eef31..607de40e 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -391,7 +391,7 @@ def _op_sense_host(self, data, ksp): tmp = cp.moveaxis(tmp, -1, 1) tmp = tmp.reshape(T * NZ, *XYZ[:2]) # After reordering, apply 2D NUFFT - ksp_batched = self.operator.op(cp.ascontiguousarray(tmp)) + ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) ksp[i * T : (i + 1) * T] = ksp_batched.get() ksp = ksp.reshape((B, C, NZ * NS)) ksp = ksp.reshape((B, C, NZ, NS)) @@ -420,7 +420,7 @@ def _op_calibless_host(self, data, ksp=None): tmp = cp.moveaxis(tmp, -1, 1) tmp = tmp.reshape(T * NZ, *XYZ[:2]) # After reordering, apply 2D NUFFT - ksp_batched = self.operator.op(cp.ascontiguousarray(tmp)) + ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) ksp[i * T : (i + 1) * T] = ksp_batched.get() ksp = ksp.reshape((B, C, NZ * NS)) @@ -467,7 +467,7 @@ def _adj_op_sense_host(self, coeffs, img): smaps_batched = self.smaps[idx_coils] ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) - tmp_adj = self.operator.adj_op(ksp_batched) + tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) tmp_adj = cp.moveaxis(tmp_adj, 1, -1) coil_img_d = cp.zeros_like(coil_img_d) @@ -494,7 +494,7 @@ def _adj_op_calibless_host(self, coeffs, img): img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) for i in range(B * C // T): ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) - tmp_adj = self.operator.adj_op(ksp_batched) + tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) tmp_adj = cp.moveaxis(tmp_adj, 1, -1) From d7f1261197adfc26805471f286487711dd246178 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:51:16 +0200 Subject: [PATCH 08/15] fix(stacked): typo --- src/mrinufft/operators/stacked.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 607de40e..b0bd0fb1 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -385,7 +385,7 @@ def _op_sense_host(self, data, ksp): cp.copyto(coil_img_d, self.smaps[idx_coils]) coil_img_d *= data_batched # FFT along Z axis (last) - coil_img_d = self.fftz(coil_img_d) + coil_img_d = self._fftz(coil_img_d) coil_img_d = coil_img_d.reshape((T, *XYZ)) tmp = coil_img_d[..., self.z_index] tmp = cp.moveaxis(tmp, -1, 1) @@ -414,7 +414,7 @@ def _op_calibless_host(self, data, ksp=None): for i in range((B * C) // T): coil_img_d.set(dataf[i * T : (i + 1) * T]) - coil_img_d = self.fftz(coil_img_d) + coil_img_d = self._fftz(coil_img_d) coil_img_d = coil_img_d.reshape((T, *XYZ)) tmp = coil_img_d[..., self.z_index] tmp = cp.moveaxis(tmp, -1, 1) @@ -472,7 +472,7 @@ def _adj_op_sense_host(self, coeffs, img): tmp_adj = cp.moveaxis(tmp_adj, 1, -1) coil_img_d = cp.zeros_like(coil_img_d) coil_img_d[..., self.z_index] = tmp_adj - coil_img_d = self.ifftz(coil_img_d) + coil_img_d = self._ifftz(coil_img_d) for t, b in enumerate(idx_batch): img_d[b, :] += coil_img_d[t] * smaps_batched[t].conj() @@ -500,7 +500,7 @@ def _adj_op_calibless_host(self, coeffs, img): coil_img_d = cp.zeros_like(coil_img_d) coil_img_d[..., self.z_index] = tmp_adj - coil_img_d = self.ifftz(coil_img_d) + coil_img_d = self._ifftz(coil_img_d) img[i * T : (i + 1) * T] = coil_img_d.get() img = img.reshape(B, C, *XYZ) return img From 8380064a13b643389e0e82999b55ed338b8f1221 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 14:53:23 +0200 Subject: [PATCH 09/15] fix: correct shape. --- src/mrinufft/operators/stacked.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index b0bd0fb1..269949fc 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -370,9 +370,9 @@ def _op_sense_host(self, data, ksp): data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) if ksp is None: - ksp = np.empty((B, C, NZ * NS), dtype=self.cpx_dtype) - ksp = ksp.reshape((B * C, NZ * NS)) - ksp_batched = cp.empty((T, NZ * NS), dtype=self.cpx_dtype) + ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) + ksp = ksp.reshape((B * C, NZ, NS)) + ksp_batched = cp.empty((T, NZ, NS), dtype=self.cpx_dtype) for i in range(B * C // T): idx_coils = np.arange(i * T, (i + 1) * T) % C idx_batch = np.arange(i * T, (i + 1) * T) // C @@ -393,7 +393,6 @@ def _op_sense_host(self, data, ksp): # After reordering, apply 2D NUFFT ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) ksp[i * T : (i + 1) * T] = ksp_batched.get() - ksp = ksp.reshape((B, C, NZ * NS)) ksp = ksp.reshape((B, C, NZ, NS)) return ksp From 04fe933c54d1b2adc827cc76dde8ce6d84958da0 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 21:52:52 +0200 Subject: [PATCH 10/15] this was the bug --- src/mrinufft/operators/interfaces/cufinufft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 690054ab..214d57c7 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -562,7 +562,7 @@ def _adj_op_calibless_host(self, coeffs, img_batched=None): if self.uses_density: density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten() - img = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) + img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype) if img_batched is None: img_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) # TODO: Add concurrency compute batch n while copying batch n+1 to device From edc855cada86d5f8d5948b01e2a26349ac59651b Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 21:53:25 +0200 Subject: [PATCH 11/15] add stacked gpu --- src/mrinufft/operators/stacked.py | 124 ++++++++++++++---------------- 1 file changed, 59 insertions(+), 65 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 269949fc..95877dc9 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -11,7 +11,13 @@ try: import cupy as cp from cupyx.scipy import fft as cpfft - from .interfaces.utils import is_cuda_array, is_host_array, pin_memory, sizeof_fmt + from .interfaces.utils import ( + is_cuda_array, + is_host_array, + pin_memory, + sizeof_fmt, + get_ptr, + ) except ImportError: CUPY_AVAILABLE = False @@ -301,7 +307,6 @@ def __init__( density=density, **kwargs, ) - # Smaps support self.smaps = smaps self.smaps_cached = False @@ -324,16 +329,23 @@ def __init__( self.smaps = pin_memory(smaps.astype(self.cpx_dtype)) self._smap_d = cp.empty(self.shape, dtype=self.cpx_dtype) + @property + def norm_factor(self): + return self.operator.norm_factor * np.sqrt(2) + @staticmethod def _fftz(data): """Apply FFT on z-axis.""" # sqrt(2) required for normalization return cpfft.fftshift( cpfft.fft( - cpfft.ifftshift(data, axes=-1), axis=-1, norm="ortho", overwrite_x=True + cpfft.ifftshift(data, axes=-1), + axis=-1, + norm="ortho", + overwrite_x=True, ), axes=-1, - ) / np.sqrt(2) + ) @staticmethod def _ifftz(data): @@ -341,10 +353,13 @@ def _ifftz(data): # sqrt(2) required for normalization return cpfft.fftshift( cpfft.ifft( - cpfft.ifftshift(data, axes=-1), axis=-1, norm="ortho", overwrite_x=True + cpfft.ifftshift(data, axes=-1), + axis=-1, + norm="ortho", + overwrite_x=False, ), axes=-1, - ) / np.sqrt(2) + ) def op(self, data, ksp=None): """Forward operator.""" @@ -361,7 +376,7 @@ def op(self, data, ksp=None): return self._safe_squeeze(ret) - def _op_sense_host(self, data, ksp): + def _op_sense_host(self, data, ksp=None): B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape NS, NZ = len(self.samples), len(self.z_index) @@ -371,9 +386,9 @@ def _op_sense_host(self, data, ksp): if ksp is None: ksp = np.empty((B, C, NZ, NS), dtype=self.cpx_dtype) - ksp = ksp.reshape((B * C, NZ, NS)) - ksp_batched = cp.empty((T, NZ, NS), dtype=self.cpx_dtype) - for i in range(B * C // T): + ksp = ksp.reshape((B * C, NZ * NS)) + ksp_batched = cp.empty((T * NZ, NS), dtype=self.cpx_dtype) + for i in range((B * C) // T): idx_coils = np.arange(i * T, (i + 1) * T) % C idx_batch = np.arange(i * T, (i + 1) * T) // C # Send the n_trans coils to gpu @@ -392,8 +407,11 @@ def _op_sense_host(self, data, ksp): tmp = tmp.reshape(T * NZ, *XYZ[:2]) # After reordering, apply 2D NUFFT ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) + ksp_batched /= self.norm_factor + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T, NZ * NS) ksp[i * T : (i + 1) * T] = ksp_batched.get() - ksp = ksp.reshape((B, C, NZ, NS)) + ksp = ksp.reshape((B, C, NZ * NS)) return ksp def _op_sense_device(self, data, ksp): @@ -403,10 +421,10 @@ def _op_calibless_host(self, data, ksp=None): B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape NS, NZ = len(self.samples), len(self.z_index) - coil_img_d = cp.empty(T, *XYZ, dtype=self.cpx_dtype) - ksp_batched = cp.empty(T, NZ * NS, dtype=self.dtype) + coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, NZ * NS), dtype=self.dtype) if ksp is None: - ksp = np.zeros((B, C, NZ * NS), dtype=self.cpx_dtype) + ksp = np.zeros((B, C, NZ, NS), dtype=self.cpx_dtype) ksp = ksp.reshape((B * C, NZ * NS)) dataf = data.reshape(B * C, *XYZ) @@ -420,10 +438,12 @@ def _op_calibless_host(self, data, ksp=None): tmp = tmp.reshape(T * NZ, *XYZ[:2]) # After reordering, apply 2D NUFFT ksp_batched = self.operator._op_calibless_device(cp.ascontiguousarray(tmp)) + ksp_batched /= self.norm_factor + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T, NZ * NS) ksp[i * T : (i + 1) * T] = ksp_batched.get() ksp = ksp.reshape((B, C, NZ * NS)) - ksp = ksp.reshape((B, C, NZ, NS)) return ksp def _op_calibless_device(self, data, ksp=None): @@ -445,7 +465,7 @@ def adj_op(self, coeffs, img=None): return self._safe_squeeze(ret) - def _adj_op_sense_host(self, coeffs, img): + def _adj_op_sense_host(self, coeffs, img_d=None): B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape NS, NZ = len(self.samples), len(self.z_index) @@ -455,9 +475,9 @@ def _adj_op_sense_host(self, coeffs, img): if img_d is None: img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype) smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) + ksp_batched = cp.empty((T, NS * NZ), dtype=self.cpx_dtype) - for i in range(B * C // T): + for i in range((B * C) // T): idx_coils = np.arange(i * T, (i + 1) * T) % C idx_batch = np.arange(i * T, (i + 1) * T) // C if not self.smaps_cached: @@ -467,9 +487,10 @@ def _adj_op_sense_host(self, coeffs, img): ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) + tmp_adj /= self.norm_factor tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) tmp_adj = cp.moveaxis(tmp_adj, 1, -1) - coil_img_d = cp.zeros_like(coil_img_d) + coil_img_d[:] = 0j coil_img_d[..., self.z_index] = tmp_adj coil_img_d = self._ifftz(coil_img_d) @@ -482,68 +503,41 @@ def _adj_op_sense_host(self, coeffs, img): def _adj_op_sense_device(self, coeffs, img): raise NotImplementedError - def _adj_op_calibless_host(self, coeffs, img): + def _adj_op_calibless_host(self, coeffs, img=None): B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape NS, NZ = len(self.samples), len(self.z_index) - - coeffs_f = coeffs.reshape(B * C, NZ * NS) + print(B, C, T, XYZ, NS, NZ) + print("coeffs", coeffs.shape) + coeffs_f = coeffs.reshape(B, C, NZ * NS) + coeffs_f = coeffs_f.reshape(B * C, NZ * NS) + print("coeffs_f", coeffs_f.shape) # Allocate Memory + ksp_batched = cp.empty((T, NZ * NS), dtype=self.cpx_dtype) + img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype) coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) - ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype) - img = np.empty((B * C, *XYZ), dtype=self.cpx_dtype) - for i in range(B * C // T): + for i in range((B * C) // T): + print(i) + ksp_batched = ksp_batched.reshape(T, NZ * NS) ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) + ksp_batched = ksp_batched.reshape(T, NZ, NS) + ksp_batched = ksp_batched.reshape(T * NZ, NS) tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) + tmp_adj /= self.norm_factor tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) + print(tmp_adj.shape) tmp_adj = cp.moveaxis(tmp_adj, 1, -1) - - coil_img_d = cp.zeros_like(coil_img_d) + print(tmp_adj.shape) + coil_img_d[:] = 0j coil_img_d[..., self.z_index] = tmp_adj coil_img_d = self._ifftz(coil_img_d) - img[i * T : (i + 1) * T] = coil_img_d.get() + img[i * T : (i + 1) * T, ...] = coil_img_d.get() + print(img[i * T : (i + 1) * T, ...].shape) img = img.reshape(B, C, *XYZ) return img def _adj_op_calibless_device(self, coeffs, img): raise NotImplementedError - def _adj_op_sense(self, coeffs, img): - B, C, XYZ = self.n_batchs, self.n_coils, self.shape - NS, NZ = len(self.samples), len(self.z_index) - - imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) - coeffs_ = coeffs.reshape((B, C * NZ, NS)) - for b in range(B): - tmp = np.ascontiguousarray(coeffs_[b, ...]) - tmp_adj = self.operator.adj_op(tmp) - # move the z axis back - tmp_adj = tmp_adj.reshape(C, NZ, *XYZ[:2]) - tmp_adj = np.moveaxis(tmp_adj, 1, -1) - imgz[b][..., self.z_index] = tmp_adj - imgc = self._ifftz(imgz) - img = img or np.empty((B, *XYZ), dtype=self.cpx_dtype) - for b in range(B): - img[b] = np.sum(imgc[b] * self.smaps.conj(), axis=0) - return img - - def _adj_op_calibless(self, coeffs, img): - B, C, XYZ = self.n_batchs, self.n_coils, self.shape - NS, NZ = len(self.samples), len(self.z_index) - - imgz = np.zeros((B, C, *XYZ), dtype=self.cpx_dtype) - coeffs_ = coeffs.reshape((B, C, NZ, NS)) - coeffs_ = coeffs.reshape((B, C * NZ, NS)) - for b in range(B): - t = np.ascontiguousarray(coeffs_[b, ...]) - adj = self.operator.adj_op(t) - # move the z axis back - adj = adj.reshape(C, NZ, *XYZ[:2]) - adj = np.moveaxis(adj, 1, -1) - imgz[b][..., self.z_index] = np.ascontiguousarray(adj) - imgz = np.reshape(imgz, (B, C, *XYZ)) - img = self._ifftz(imgz) - return img - def traj3d2stacked(samples, dim_z, n_samples=0): """Convert a 3D trajectory into a trajectory and the z-stack index. From 5d760d4bb9764334928257e44607d5bbe6f4b921 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 21:54:08 +0200 Subject: [PATCH 12/15] add stacked gpu tests --- tests/test_stacked_gpu.py | 149 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 tests/test_stacked_gpu.py diff --git a/tests/test_stacked_gpu.py b/tests/test_stacked_gpu.py new file mode 100644 index 00000000..6d57c437 --- /dev/null +++ b/tests/test_stacked_gpu.py @@ -0,0 +1,149 @@ +"""Test for the stacked NUFFT operator. + +The tests compares the stacked NUFFT (which uses FFT in the z-direction) +and the fully 3D ones. +""" +import numpy as np + +import numpy.testing as npt +from pytest_cases import parametrize_with_cases, parametrize, fixture + +from mrinufft.operators.stacked import ( + MRIStackedNUFFTGPU, + stacked2traj3d, + traj3d2stacked, +) +from mrinufft import get_operator +from case_trajectories import CasesTrajectories + + +@fixture(scope="module") +@parametrize( + "n_batchs, n_coils, sense", + [(1, 1, False), (1, 4, False), (1, 4, True), (3, 4, False), (3, 4, True)], +) +@parametrize("z_index", ["full", "random_mask"]) +@parametrize("backend", ["cufinufft"]) +@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories, glob="*2D") +def operator(request, backend, kspace_locs, shape, z_index, n_batchs, n_coils, sense): + """Initialize the stacked and non-stacked operators.""" + shape3d = (*shape, shape[-1] - 2) # add a 3rd dimension + + if z_index == "full": + z_index = np.arange(shape3d[-1]) + z_index_ = z_index + elif z_index == "random_mask": + z_index = np.random.rand(shape3d[-1]) > 0.5 + z_index_ = np.arange(shape3d[-1])[z_index] + + kspace_locs3d = stacked2traj3d(kspace_locs, z_index_, shape3d[-1]) + # smaps support + if sense: + smaps = 1j * np.random.rand(n_coils, *shape3d) + smaps += np.random.rand(n_coils, *shape3d) + else: + smaps = None + + # Setup the operators + ref = get_operator(backend)( + kspace_locs3d, + shape=shape3d, + n_coils=n_coils, + n_batchs=n_batchs, + smaps=smaps, + ) + + stacked = MRIStackedNUFFTGPU( + samples=kspace_locs, + shape=shape3d, + z_index=z_index, + n_coils=n_coils, + n_trans=2 if n_coils > 1 else 1, + n_batchs=n_batchs, + smaps=smaps, + ) + return stacked, ref + + +@fixture(scope="module") +def stacked_op(operator): + """Return operator.""" + return operator[0] + + +@fixture(scope="module") +def ref_op(operator): + """Return ref operator.""" + return operator[1] + + +@fixture(scope="module") +def image_data(stacked_op): + """Generate a random image.""" + B, C = stacked_op.n_batchs, stacked_op.n_coils + if stacked_op.smaps is None: + img = np.random.randn(B, C, *stacked_op.shape).astype(stacked_op.cpx_dtype) + elif stacked_op.smaps is not None and stacked_op.n_coils > 1: + img = np.random.randn(B, *stacked_op.shape).astype(stacked_op.cpx_dtype) + + img += 1j * np.random.randn(*img.shape).astype(stacked_op.cpx_dtype) + return img + + +@fixture(scope="module") +def kspace_data(stacked_op): + """Generate a random kspace data.""" + B, C = stacked_op.n_batchs, stacked_op.n_coils + kspace = (1j * np.random.randn(B, C, stacked_op.n_samples)).astype( + stacked_op.cpx_dtype + ) + kspace += np.random.randn(B, C, stacked_op.n_samples).astype(stacked_op.cpx_dtype) + return kspace + + +def test_stack_forward(operator, stacked_op, ref_op, image_data): + """Compare the stack interface to the 3D NUFFT implementation.""" + kspace_nufft = stacked_op.op(image_data).squeeze() + kspace_ref = ref_op.op(image_data).squeeze() + npt.assert_allclose(kspace_nufft, kspace_ref, atol=1e-4, rtol=1e-1) + + +def test_stack_backward(operator, stacked_op, ref_op, kspace_data): + """Compare the stack interface to the 3D NUFFT implementation.""" + image_nufft = stacked_op.adj_op(kspace_data.copy()).squeeze() + image_ref = ref_op.adj_op(kspace_data.copy()).squeeze() + if stacked_op.n_coils > 1: + print( + np.max( + np.abs(image_nufft - image_ref).reshape( + stacked_op.n_batchs, stacked_op.n_coils, -1 + ), + axis=-1, + ) + ) + print(image_nufft.shape, image_ref.shape) + npt.assert_allclose(image_nufft, image_ref, atol=1e-4, rtol=1e-1) + + +def test_stack_auto_adjoint(operator, stacked_op, kspace_data, image_data): + """Test the adjoint property of the stacked NUFFT operator.""" + kspace = stacked_op.op(image_data) + image = stacked_op.adj_op(kspace_data) + leftadjoint = np.vdot(image, image_data) + rightadjoint = np.vdot(kspace, kspace_data) + + npt.assert_allclose(leftadjoint.conj(), rightadjoint, atol=1e-4, rtol=1e-4) + + +def test_stacked2traj3d(): + """Test the conversion from stacked to 3d trajectory.""" + dimz = 64 + traj2d = np.random.randn(100, 2) + z_index = np.random.choice(dimz, 20, replace=False) + + traj3d = stacked2traj3d(traj2d, z_index, dimz) + + traj2d, z_index2 = traj3d2stacked(traj3d, dimz) + + npt.assert_allclose(traj2d, traj2d) + npt.assert_allclose(z_index, z_index2) From b2646716bda794911f7a8f76b5e110334b4c37ca Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 16 Oct 2023 21:58:18 +0200 Subject: [PATCH 13/15] fix: cleanup --- src/mrinufft/operators/stacked.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 95877dc9..015ce3db 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -16,7 +16,6 @@ is_host_array, pin_memory, sizeof_fmt, - get_ptr, ) except ImportError: CUPY_AVAILABLE = False @@ -331,6 +330,7 @@ def __init__( @property def norm_factor(self): + """Norm factor of the operator.""" return self.operator.norm_factor * np.sqrt(2) @staticmethod @@ -506,17 +506,13 @@ def _adj_op_sense_device(self, coeffs, img): def _adj_op_calibless_host(self, coeffs, img=None): B, C, T, XYZ = self.n_batchs, self.n_coils, self.n_trans, self.shape NS, NZ = len(self.samples), len(self.z_index) - print(B, C, T, XYZ, NS, NZ) - print("coeffs", coeffs.shape) coeffs_f = coeffs.reshape(B, C, NZ * NS) coeffs_f = coeffs_f.reshape(B * C, NZ * NS) - print("coeffs_f", coeffs_f.shape) # Allocate Memory ksp_batched = cp.empty((T, NZ * NS), dtype=self.cpx_dtype) img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype) coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype) for i in range((B * C) // T): - print(i) ksp_batched = ksp_batched.reshape(T, NZ * NS) ksp_batched.set(coeffs_f[i * T : (i + 1) * T]) ksp_batched = ksp_batched.reshape(T, NZ, NS) @@ -524,14 +520,11 @@ def _adj_op_calibless_host(self, coeffs, img=None): tmp_adj = self.operator._adj_op_calibless_device(ksp_batched) tmp_adj /= self.norm_factor tmp_adj = tmp_adj.reshape((T, NZ, *XYZ[:2])) - print(tmp_adj.shape) tmp_adj = cp.moveaxis(tmp_adj, 1, -1) - print(tmp_adj.shape) coil_img_d[:] = 0j coil_img_d[..., self.z_index] = tmp_adj coil_img_d = self._ifftz(coil_img_d) img[i * T : (i + 1) * T, ...] = coil_img_d.get() - print(img[i * T : (i + 1) * T, ...].shape) img = img.reshape(B, C, *XYZ) return img From 38372dcf6224ef7e03dd27fe2ceb44902c7ce5fa Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 17 Oct 2023 10:34:25 +0200 Subject: [PATCH 14/15] style: sort imports. --- src/mrinufft/operators/stacked.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mrinufft/operators/stacked.py b/src/mrinufft/operators/stacked.py index 015ce3db..f2c272a9 100644 --- a/src/mrinufft/operators/stacked.py +++ b/src/mrinufft/operators/stacked.py @@ -4,19 +4,19 @@ import numpy as np import scipy as sp -from .base import FourierOperatorBase, proper_trajectory, check_backend -from . import get_operator +from .base import FourierOperatorBase, check_backend, get_operator, proper_trajectory +from .interfaces.utils import ( + is_cuda_array, + is_host_array, + pin_memory, + sizeof_fmt, +) CUPY_AVAILABLE = True try: import cupy as cp from cupyx.scipy import fft as cpfft - from .interfaces.utils import ( - is_cuda_array, - is_host_array, - pin_memory, - sizeof_fmt, - ) + except ImportError: CUPY_AVAILABLE = False From a65716807fc353db96e5df35bc06e56930127501 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 17 Oct 2023 10:38:40 +0200 Subject: [PATCH 15/15] ci: also install cupy. --- .github/workflows/test-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 96f2b4f4..147e0001 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -119,6 +119,7 @@ jobs: cmake -DFINUFFT_USE_CUDA=1 ../ && cmake --build . && cp libcufinufft.so ../python/cufinufft/. # enter venv source $RUNNER_WORKSPACE/venv/bin/activate + pip install cupy-cuda11x cd $RUNNER_WORKSPACE/finufft/python/cufinufft python setup.py develop # FIXME: This is hardcoded