Skip to content

Commit

Permalink
feat: add batch support for gpunufft. (mind-inria#81)
Browse files Browse the repository at this point in the history
* feat: add batch support for gpunufft.

* Make test density stronger

* Fix testing

* Bump gpuNUFFT supported version

* Make sure test does not run for other than gpuNUFFT

* Fix voronoi

* ruff

* Fixes

---------

Co-authored-by: GILIYAR RADHAKRISHNA Chaithya <cg260486@is247382.intra.cea.fr>
  • Loading branch information
2 people authored and chaithyagr committed Apr 11, 2024
1 parent 6c17b8c commit d7521bd
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dynamic = ["version"]

[project.optional-dependencies]

gpunufft = ["gpuNUFFT>=0.7.1", "cupy-cuda11x"]
gpunufft = ["gpuNUFFT>=0.7.4", "cupy-cuda11x"]
cufinufft = ["cufinufft", "cupy-cuda11x"]
finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
Expand Down
83 changes: 64 additions & 19 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def __init__(
)
self.pinned_image = pinned_image
self.pinned_kspace = pinned_kspace

self.osf = osf
self.pinned_smaps = pinned_smaps
self.operator = NUFFTOp(
np.reshape(samples, samples.shape[::-1], order="F"),
Expand All @@ -193,8 +193,9 @@ def _reshape_image(self, image, direction="op"):
return xp.asarray([c.ravel(order="F") for c in image], dtype=xp.complex64).T
else:
if self.uses_sense or self.n_coils == 1:
return image.squeeze().astype(xp.complex64).T
return xp.asarray([c.T for c in image], dtype=xp.complex64)
# Support for one additional dimension
return image.squeeze().astype(xp.complex64).T[None]
return xp.asarray([c.T for c in image], dtype=xp.complex64).squeeze()

def op_direct(self, image, kspace=None, interpolate_data=False):
"""Compute the masked non-Cartesian Fourier transform.
Expand Down Expand Up @@ -304,15 +305,16 @@ def adj_op_direct(self, coeffs, image=None, grid_data=False):
adjoint operator of Non Uniform Fourier transform of the
input coefficients.
"""
C = 1 if self.uses_sense else self.n_coils
coeffs = coeffs.astype(cp.complex64)
if image is None:
image = cp.empty(
(np.prod(self.shape), (1 if self.uses_sense else self.n_coils)),
(np.prod(self.shape), C),
dtype=cp.complex64,
order="F",
)
self.operator.adj_op_direct(coeffs.data.ptr, image.data.ptr, grid_data)
image = image.reshape(self.n_coils, *self.shape[::-1])
image = image.reshape(C, *self.shape[::-1])
return self._reshape_image(image, "adjoint")


Expand All @@ -334,10 +336,12 @@ class MRIGpuNUFFT(FourierOperatorBase):
if True, the density compensation is estimated from the samples
locations. If an array is passed, it is used as the density
compensation.
squeeze_dims: bool default True
This has no effect, gpuNUFFT always squeeze the data.
squeeze_dims: bool, default True
If True, will try to remove the singleton dimension for batch and coils.
smaps: np.ndarray default None
Holds the sensitivity maps for SENSE reconstruction.
n_trans: int, default =1
This has no effect for now.
kwargs: extra keyword args
these arguments are passed to gpuNUFFT operator. This is used
only in gpuNUFFT
Expand All @@ -351,9 +355,11 @@ def __init__(
samples,
shape,
n_coils=1,
n_batchs=1,
n_trans=1,
density=None,
smaps=None,
squeeze_dims=False,
squeeze_dims=True,
eps=1e-3,
**kwargs,
):
Expand All @@ -367,8 +373,9 @@ def __init__(
self.samples = proper_trajectory(samples, normalize="unit")
self.dtype = self.samples.dtype
self.n_coils = n_coils
self.n_batchs = n_batchs
self.smaps = smaps

self.squeeze_dims = squeeze_dims
self.compute_density(density)
self.impl = RawGpuNUFFT(
samples=self.samples,
Expand Down Expand Up @@ -396,21 +403,32 @@ def op(self, data, coeffs=None):
np.ndarray
Masked Fourier transform of the input image.
"""
B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples

op_func = self.impl.op
if is_cuda_array(data):
op_func = self.impl.op_direct
if not self.impl.use_gpu_direct:
warnings.warn(
"Using direct GPU array without passing "
"`use_gpu_direct=True`, this is memory inefficient."
)
return self.impl.op_direct(data, coeffs)
return self.impl.op(
data,
coeffs,
)
data_ = data.reshape((B, 1 if self.uses_sense else C, *XYZ))
if coeffs is not None:
coeffs.reshape((B, C, K))
result = []
for i in range(B):
if coeffs is None:
result.append(op_func(data_[i], None))
else:
op_func(data_[i], coeffs[i])
if coeffs is None:
coeffs = get_array_module(data).stack(result)
return self._safe_squeeze(coeffs)

@with_numpy_cupy
def adj_op(self, coeffs, data=None):
"""Compute adjoint Non Unform Fourier Transform.
"""Compute adjoint Non Uniform Fourier Transform.
Parameters
----------
Expand All @@ -424,14 +442,28 @@ def adj_op(self, coeffs, data=None):
np.ndarray
Inverse discrete Fourier transform of the input coefficients.
"""
B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples

adj_op_func = self.impl.adj_op
if is_cuda_array(coeffs):
adj_op_func = self.impl.adj_op_direct
if not self.impl.use_gpu_direct:
warnings.warn(
"Using direct GPU array without passing "
"`use_gpu_direct=True`, this is memory inefficient."
)
return self.impl.adj_op_direct(coeffs, data)
return self.impl.adj_op(coeffs, data)
coeffs_ = coeffs.reshape(B, C, K)
if data is not None:
data.reshape((B, 1 if self.uses_sense else C, *XYZ))
result = []
for i in range(B):
if data is None:
result.append(adj_op_func(coeffs_[i], None))
else:
adj_op_func(coeffs_[i], data[i])
if data is None:
data = get_array_module(coeffs).stack(result)
return self._safe_squeeze(data)

@property
def uses_sense(self):
Expand All @@ -457,10 +489,11 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs):
raise ValueError(
"gpuNUFFT is not available, cannot " "estimate the density compensation"
)
volume_shape = (np.array(volume_shape) * osf).astype(int)
grid_op = MRIGpuNUFFT(
samples=kspace_loc,
shape=volume_shape,
osf=osf,
osf=1,
**kwargs,
)
density_comp = grid_op.impl.operator.estimate_density_comp(
Expand All @@ -471,7 +504,6 @@ def pipe(cls, kspace_loc, volume_shape, num_iterations=10, osf=2, **kwargs):
def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):
"""Return the Lipschitz constant of the operator.
Parameters
----------
max_iter: int
Number of iteration to perform to estimate the Lipschitz constant.
Expand All @@ -497,3 +529,16 @@ def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):
return tmp_op.impl.operator.get_spectral_radius(
max_iter=max_iter, tolerance=tolerance
)

def _safe_squeeze(self, arr):
"""Squeeze the first two dimensions of shape of the operator."""
if self.squeeze_dims:
try:
arr = arr.squeeze(axis=1)
except ValueError:
pass
try:
arr = arr.squeeze(axis=0)
except ValueError:
pass
return arr
13 changes: 12 additions & 1 deletion tests/case_trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import scipy as sp

from mrinufft.trajectories import initialize_2D_radial
from mrinufft.trajectories.tools import stack
from mrinufft.trajectories.tools import stack, rotate


class CasesTrajectories:
Expand Down Expand Up @@ -34,12 +34,23 @@ def case_radial2D(self, Nc=10, Ns=500, N=64):
trajectory = initialize_2D_radial(Nc, Ns)
return trajectory, (N, N)

def case_nyquist_radial2D(self, Nc=32 * 4, Ns=16, N=32):
"""Create a 2D radial trajectory."""
trajectory = initialize_2D_radial(Nc, Ns)
return trajectory, (N, N)

def case_radial3D(self, Nc=20, Ns=1000, Nr=20, N=64, expansion="rotations"):
"""Create a 3D radial trajectory."""
trajectory = initialize_2D_radial(Nc, Ns)
trajectory = stack(trajectory, nb_stacks=Nr)
return trajectory, (N, N, N)

def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32):
"""Create a 3D radial trajectory."""
trajectory = initialize_2D_radial(Nc, Ns)
trajectory = rotate(trajectory, nb_rotations=Nr)
return trajectory, (N, N, N)

def case_grid2D(self, N=16):
"""Create a 2D cartesian grid of frequencies locations."""
freq_1d = sp.fft.fftfreq(N)
Expand Down
8 changes: 7 additions & 1 deletion tests/helpers/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ def assert_correlate(a, b, slope=1.0, slope_err=1e-3, r_value_err=1e-3):
a.flatten(), b.flatten()
)
abs_slope_reg = abs(slope_reg)
if abs(abs_slope_reg - slope) > slope_err:
if r_value_err is not None and abs(rvalue - 1) > r_value_err:
raise AssertionError(
f"RValue {rvalue} != 1 +- {r_value_err}\n "
f"intercept={intercept}, stderr={stderr}, "
f"intercept_stderr={intercept_stderr}"
)
if slope_err is not None and abs(abs_slope_reg - slope) > slope_err:
raise AssertionError(
f"Slope {abs_slope_reg} != {slope} +- {slope_err}\n r={rvalue},"
f"intercept={intercept}, stderr={stderr}, "
Expand Down
5 changes: 4 additions & 1 deletion tests/test_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import numpy.testing as npt
import pytest
from pytest_cases import parametrize_with_cases, parametrize, fixture
from helpers import (
assert_correlate,
Expand Down Expand Up @@ -37,7 +38,7 @@
cases=CasesTrajectories,
glob="*random*",
)
@parametrize(backend=["finufft", "cufinufft"])
@parametrize(backend=["gpunufft", "finufft", "cufinufft"])
def operator(
request,
kspace_locs,
Expand All @@ -49,6 +50,8 @@ def operator(
backend="finufft",
):
"""Generate a batch operator."""
if n_trans != 1 and backend == "gpunufft":
pytest.skip("Duplicate case.")
if sense:
smaps = 1j * np.random.rand(n_coils, *shape)
smaps += np.random.rand(n_coils, *shape)
Expand Down
43 changes: 29 additions & 14 deletions tests/test_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import numpy as np
import numpy.testing as npt
from pytest_cases import fixture, parametrize, parametrize_with_cases
from pytest_cases import parametrize, parametrize_with_cases

from case_trajectories import CasesTrajectories
from helpers import assert_correlate
from mrinufft.density import cell_count, voronoi
from mrinufft.density import cell_count, voronoi, pipe
from mrinufft.density.utils import normalize_weights
from mrinufft._utils import proper_trajectory

Expand Down Expand Up @@ -37,15 +37,11 @@ def slow_cell_count2D(traj, shape, osf):
return normalize_weights(weights)


@fixture(scope="module")
def radial_distance():
def radial_distance(traj, shape):
"""Compute the radial distance of a trajectory."""
traj, shape = CasesTrajectories().case_radial2D()

proper_traj = proper_trajectory(traj, normalize="unit")
weights = 2 * np.pi * np.sqrt(proper_traj[:, 0] ** 2 + proper_traj[:, 1] ** 2)

return normalize_weights(weights)
weights = np.linalg.norm(proper_traj, axis=-1)
return weights


@parametrize("osf", [1, 1.25, 2])
Expand All @@ -58,13 +54,32 @@ def test_cell_count2D(traj, shape, osf):


@parametrize_with_cases("traj, shape", cases=[CasesTrajectories.case_radial2D])
def test_voronoi(traj, shape, radial_distance):
def test_voronoi(traj, shape):
"""Test the voronoi method."""
result = voronoi(traj)
distance = radial_distance(traj, shape)
result = result / np.mean(result)
distance = distance / np.mean(distance)
assert_correlate(result, distance, slope=1)

assert_correlate(result, radial_distance, slope=2 * np.pi)


def test_pipe():
@parametrize("osf", [1, 1.25, 2])
@parametrize_with_cases(
"traj, shape",
cases=[
CasesTrajectories.case_nyquist_radial2D,
CasesTrajectories.case_nyquist_radial3D,
],
)
@parametrize(backend=["gpunufft"])
def test_pipe(backend, traj, shape, osf):
"""Test the pipe method."""
pass
distance = radial_distance(traj, shape)
result = pipe(traj, shape, osf=osf, num_iterations=10)
result = result / np.mean(result)
distance = distance / np.mean(distance)
if osf != 2:
# If OSF < 2, we dont perfectly estimate
assert_correlate(result, distance, slope=1, slope_err=None, r_value_err=0.2)
else:
assert_correlate(result, distance, slope=1, slope_err=0.1, r_value_err=0.1)

0 comments on commit d7521bd

Please sign in to comment.