Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Direct Torch and CUPY support for gpUNUFFT arrays #80

Merged
merged 15 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ jobs:
path: coverage_${{ matrix.backend}}

test-gpu:
runs-on: GPU
runs-on: gpu
if: ${{ !contains(github.event.head_commit.message, 'style')}}
strategy:
matrix:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dynamic = ["version"]

[project.optional-dependencies]

gpunufft = ["gpunufft", "cupy-cuda11x"]
gpunufft = ["gpuNUFFT>=0.7.1", "cupy-cuda11x"]
cufinufft = ["cufinufft", "cupy-cuda11x"]
finufft = ["finufft"]
pynfft = ["pynfft2", "cython<3.0.0"]
Expand Down
40 changes: 40 additions & 0 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,16 @@
from functools import partial, wraps
import numpy as np
from mrinufft._utils import power_method, auto_cast, get_array_module
from mrinufft.operators.interfaces.utils import is_cuda_array

from mrinufft.density import get_density

CUPY_AVAILABLE = True
try:
import cupy as cp
except ImportError:
CUPY_AVAILABLE = False

# Mapping between numpy float and complex types.
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}

Expand Down Expand Up @@ -114,6 +121,39 @@ def wrapper(self, data, *args, **kwargs):
return wrapper


def with_numpy_cupy(fun):
    """Ensure the decorated method works internally with numpy or cupy arrays.

    Torch tensors passed as ``data``/``output`` are converted before calling
    ``fun`` and the result is converted back to a torch tensor on the
    original device:

    - torch CUDA tensor -> cupy array (zero-copy, via the DLPack protocol);
    - torch tensor on any other device -> numpy array (through host memory);
    - numpy/cupy inputs are passed through untouched.

    Raises
    ------
    RuntimeError
        If a torch CUDA tensor is given but cupy is not installed.
    """

    @wraps(fun)
    def wrapper(self, data, output=None, *args, **kwargs):
        xp = get_array_module(data)
        # Evaluate the dispatch conditions once; they are reused on the way out.
        is_torch = xp.__name__ == "torch"
        torch_cuda = is_torch and is_cuda_array(data)
        if torch_cuda:
            if not CUPY_AVAILABLE:
                # Without this guard, `cp` below would raise a cryptic NameError.
                raise RuntimeError(
                    "cupy is required to handle torch CUDA tensors."
                )
            # Zero-copy handoff to cupy through DLPack.
            data_ = cp.from_dlpack(data)
            output_ = cp.from_dlpack(output) if output is not None else None
        elif is_torch:
            # CPU (or other non-CUDA device) torch tensor: go through numpy.
            data_ = data.to("cpu").numpy()
            output_ = output.to("cpu").numpy() if output is not None else None
        else:
            data_ = data
            output_ = output

        ret_ = fun(self, data_, output_, *args, **kwargs)

        if torch_cuda:
            # Wrap the cupy result back as a torch tensor on the source device.
            return xp.as_tensor(ret_, device=data.device)
        if is_torch:
            ret_ = xp.from_numpy(ret_)
            # Restore the original (possibly non-CPU, non-CUDA) device.
            return ret_ if data.is_cpu else ret_.to(data.device)
        return ret_

    return wrapper


class FourierOperatorBase(ABC):
"""Base Fourier Operator class.

Expand Down
18 changes: 3 additions & 15 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import warnings
import numpy as np
from mrinufft.operators.base import FourierOperatorBase
from mrinufft.operators.base import FourierOperatorBase, with_numpy_cupy
from mrinufft._utils import (
proper_trajectory,
get_array_module,
Expand Down Expand Up @@ -239,6 +239,7 @@ def __init__(
)
# Support for concurrent stream and computations.

@with_numpy_cupy
@nvtx_mark()
def op(self, data, ksp_d=None):
r"""Non Cartesian MRI forward operator.
Expand All @@ -262,9 +263,6 @@ def op(self, data, ksp_d=None):
check_size(data, (self.n_batchs, *self.shape))
else:
check_size(data, (self.n_batchs, self.n_coils, *self.shape))
xp = get_array_module(data)
if xp.__name__ == "torch" and data.is_cpu:
data = data.numpy()
data = auto_cast(data, self.cpx_dtype)
# Dispatch to special case.
if self.uses_sense and is_cuda_array(data):
Expand All @@ -278,10 +276,6 @@ def op(self, data, ksp_d=None):
ret = op_func(data, ksp_d)

ret /= self.norm_factor
if xp.__name__ == "torch" and is_cuda_array(ret):
ret = xp.as_tensor(ret, device=data.device)
elif xp.__name__ == "torch":
ret = xp.from_numpy(ret)
return self._safe_squeeze(ret)

def _op_sense_device(self, data, ksp_d=None):
Expand Down Expand Up @@ -368,6 +362,7 @@ def __op(self, image_d, coeffs_d):
return self.raw_op.type2(image_d, coeffs_d)

@nvtx_mark()
@with_numpy_cupy
def adj_op(self, coeffs, img_d=None):
"""Non Cartesian MRI adjoint operator.

Expand All @@ -379,9 +374,6 @@ def adj_op(self, coeffs, img_d=None):
-------
Array in the same memory space of coeffs. (ie on cpu or gpu Memory).
"""
xp = get_array_module(coeffs)
if xp.__name__ == "torch" and coeffs.is_cpu:
coeffs = coeffs.numpy()
coeffs = auto_cast(coeffs, self.cpx_dtype)
check_size(coeffs, (self.n_batchs, self.n_coils, self.n_samples))
# Dispatch to special case.
Expand All @@ -397,10 +389,6 @@ def adj_op(self, coeffs, img_d=None):
ret = adj_op_func(coeffs, img_d)
ret /= self.norm_factor

if xp.__name__ == "torch" and is_cuda_array(ret):
ret = xp.as_tensor(ret, device=coeffs.device)
elif xp.__name__ == "torch":
ret = xp.from_numpy(ret)
return self._safe_squeeze(ret)

def _adj_op_sense_device(self, coeffs, img_d=None):
Expand Down
Loading
Loading