Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Plugging a bunch of test holes back to CI #160

Merged
merged 35 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
365afc8
remove use_smap + fix pipe test
Jul 11, 2024
8ef8276
simplifying the pipe test
Jul 12, 2024
1a71159
change osf default value
Jul 12, 2024
746097b
Adding a bunch of test holes back to CI
chaithyagr Jul 12, 2024
edafdfc
reenable finufft
chaithyagr Jul 12, 2024
4c8943b
TEMP commit to just run few tests. TODO to get it back
chaithyagr Jul 12, 2024
813afc6
MRI-NUFFT
chaithyagr Jul 12, 2024
d615a54
fix a bunch of bugs
chaithyagr Jul 15, 2024
351e082
Added tensorflow-probabiloty
chaithyagr Jul 15, 2024
182d42b
Added codes
chaithyagr Jul 15, 2024
d582f55
e-3
chaithyagr Jul 15, 2024
e38622a
A bunch of fixes
chaithyagr Jul 15, 2024
aa53c0c
Added back test
chaithyagr Jul 15, 2024
f9c1759
Move to tensorflow
chaithyagr Jul 15, 2024
a7b9e02
A bunch of fixes
chaithyagr Jul 15, 2024
82520ea
A bunch of factory
chaithyagr Jul 15, 2024
a8b6134
A bunch of more fixes
chaithyagr Jul 15, 2024
873e313
Finally fixed
chaithyagr Jul 15, 2024
1de5374
Finally fixed
chaithyagr Jul 15, 2024
84fa9d9
Fuinal changes
chaithyagr Jul 15, 2024
07021c1
Move back to testing everything
chaithyagr Jul 15, 2024
0560a3d
Fixed test_dens
chaithyagr Jul 16, 2024
f350795
fixed on style
chaithyagr Jul 16, 2024
779ee71
Fixed cufinufft test_update
chaithyagr Jul 16, 2024
e281d03
Make test tighter
chaithyagr Jul 16, 2024
2830cf2
Fix CI
chaithyagr Jul 16, 2024
8d228b4
Fix CI
chaithyagr Jul 16, 2024
c839acb
Disable cancel in progress
chaithyagr Jul 16, 2024
6e34c45
Disable gpuNUFFT
chaithyagr Jul 16, 2024
916694b
Working installs
chaithyagr Jul 16, 2024
b2f447e
Added back gpunufft to ci
chaithyagr Jul 16, 2024
8a76a7d
Minor fixes
chaithyagr Jul 17, 2024
bb79802
Hopefully final fixes
chaithyagr Jul 17, 2024
c76d65a
Taking in comments
chaithyagr Jul 17, 2024
8b52a13
Fix
chaithyagr Jul 17, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ jobs:
if: ${{ !contains(github.event.head_commit.message, 'style')}}
strategy:
matrix:
backend: [gpunufft, cufinufft, torchkbnufft-gpu]
backend: [cufinufft, gpunufft, torchkbnufft-gpu, tensorflow]

steps:
- uses: actions/checkout@v3
Expand All @@ -123,20 +123,20 @@ jobs:


- name: Install backend
if: ${{ matrix.backend == 'gpunufft' || matrix.backend == 'cufinufft' || matrix.backend == 'torchkbnufft-gpu' }}
shell: bash
run: |
source $RUNNER_WORKSPACE/venv/bin/activate
export CUDA_BIN_PATH=/usr/local/cuda-11.8/
export PATH=/usr/local/cuda-11.8/bin/${PATH:+:${PATH}}
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib/{LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
export PATH=/usr/local/cuda-11.8/bin/:${PATH}
export LD_LIBRARY_PATH=/usr/local/cuda-11.8/lib64/:${LD_LIBRARY_PATH}
if [[ ${{ matrix.backend }} == "torchkbnufft-gpu" ]]; then
pip install torchkbnufft
elif [[ ${{ matrix.backend }} == "tensorflow" ]]; then
pip install tensorflow-mri==0.21.0 tensorflow-probability==0.17.0 tensorflow-io==0.27.0 matplotlib==3.7
else
pip install ${{ matrix.backend }}
fi


- name: Run Tests
shell: bash
run: |
Expand Down
6 changes: 3 additions & 3 deletions src/mrinufft/operators/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def backward(ctx, dy):
if ctx.nufft_op._grad_wrt_traj:
im_size = x.size()[1:]
factor = 1
if ctx.nufft_op.backend in ["gpunufft", "finufft"]:
if ctx.nufft_op.backend in ["gpunufft"]:
factor *= np.pi * 2
r = [
torch.linspace(-size / 2, size / 2 - 1, size) * factor
Expand Down Expand Up @@ -88,7 +88,7 @@ def backward(ctx, dx):
ctx.nufft_op.raw_op.toggle_grad_traj()
im_size = dx.size()[2:]
factor = 1
if ctx.nufft_op.backend in ["gpunufft", "finufft"]:
if ctx.nufft_op.backend in ["gpunufft"]:
factor *= np.pi * 2
r = [
torch.linspace(-size / 2, size / 2 - 1, size) * factor
Expand Down Expand Up @@ -133,7 +133,7 @@ def __init__(self, nufft_op, wrt_data=True, wrt_traj=False):
self.nufft_op = nufft_op
self.nufft_op._grad_wrt_traj = wrt_traj
if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]:
self.nufft_op.raw_op._make_plan_grad()
self.nufft_op._make_plan_grad()
self.nufft_op._grad_wrt_data = wrt_data

def op(self, x):
Expand Down
31 changes: 29 additions & 2 deletions src/mrinufft/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,33 @@ def wrapper(self, data, *args, **kwargs):
return wrapper


def with_tensorflow(fun):
"""Ensure the function works internally with tensorflow array."""

@wraps(fun)
def wrapper(self, data, *args, **kwargs):
import tensorflow as tf

xp = get_array_module(data)
if xp.__name__ == "torch":
data_ = tf.convert_to_tensor(data.cpu())
elif xp.__name__ == "cupy":
data_ = tf.experimental.dlpack.from_dlpack(data.toDlpack())
else:
data_ = tf.convert_to_tensor(data)

ret_ = fun(self, data_, *args, **kwargs)

if xp.__name__ in ["torch", "cupy"]:
return xp.from_dlpack(tf.experimental.dlpack.to_dlpack(ret_))
elif xp.__name__ == "numpy":
return ret_.numpy()
else:
return ret_

return wrapper


def with_numpy_cupy(fun):
"""Ensure the function works internally with numpy or cupy array."""

Expand Down Expand Up @@ -589,7 +616,7 @@ def __init__(
self.shape = shape

# we will access the samples by their coordinate first.
self.samples = samples.reshape(-1, len(shape))
self._samples = samples.reshape(-1, len(shape))
self.dtype = self.samples.dtype
if n_coils < 1:
raise ValueError("n_coils should be ≥ 1")
Expand Down Expand Up @@ -753,7 +780,7 @@ def _grad_sense(self, image_data, obs_data):

dataf = image_data.reshape((B, *XYZ))
obs_dataf = obs_data.reshape((B * C, K))
grad = np.empty_like(dataf)
grad = np.zeros_like(dataf)
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved

coil_img = np.empty((T, *XYZ), dtype=self.cpx_dtype)
coil_ksp = np.empty((T, K), dtype=self.cpx_dtype)
Expand Down
47 changes: 24 additions & 23 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,56 +58,43 @@ def __init__(
**kwargs,
):
self.shape = shape
self.samples = samples
self.ndim = len(shape)
chaithyagr marked this conversation as resolved.
Show resolved Hide resolved
self.eps = float(eps)
self.n_trans = n_trans

self._dtype = samples.dtype
# the first element is dummy to index type 1 with 1
# and type 2 with 2.
self.plans = [None, None, None]
self.grad_plan = None

for i in [1, 2]:
self._make_plan(i, **kwargs)
self._set_pts(i)
self._set_pts(i, samples)

@property
def dtype(self):
"""Return the dtype (precision) of the transform."""
try:
return self.plans[1].dtype
except AttributeError:
return DTYPE_R2C[str(self.samples.dtype)]
return DTYPE_R2C[str(self._dtype)]

def _make_plan(self, typ, **kwargs):
self.plans[typ] = Plan(
typ,
self.shape,
self.n_trans,
self.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
**kwargs,
)

def _make_plan_grad(self, **kwargs):
self.grad_plan = Plan(
2,
self.shape,
self.n_trans,
self.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
isign=1,
dtype=DTYPE_R2C[str(self._dtype)],
**kwargs,
)
self._set_pts(typ="grad")

def _set_pts(self, typ):
def _set_pts(self, typ, samples):
plan = self.grad_plan if typ == "grad" else self.plans[typ]
plan.setpts(
cp.array(self.samples[:, 0], copy=False),
cp.array(self.samples[:, 1], copy=False),
cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None,
cp.array(samples[:, 0], copy=False),
cp.array(samples[:, 1], copy=False),
cp.array(samples[:, 2], copy=False) if self.ndim == 3 else None,
)

def _destroy_plan(self, typ):
Expand Down Expand Up @@ -267,11 +254,13 @@ def __init__(
@FourierOperatorBase.samples.setter
def samples(self, samples):
"""Update the plans when changing the samples."""
self._samples = samples
self._samples = np.asfortranarray(
proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False)
)
for typ in [1, 2, "grad"]:
if typ == "grad" and not self._grad_wrt_traj:
continue
self.raw_op._set_pts(typ)
self.raw_op._set_pts(typ, samples)

@with_numpy_cupy
@nvtx_mark()
Expand Down Expand Up @@ -792,6 +781,18 @@ def __repr__(self):
")"
)

def _make_plan_grad(self, **kwargs):
self.raw_op.grad_plan = Plan(
2,
self.shape,
self.n_trans,
self.raw_op.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
isign=1,
**kwargs,
)
self.raw_op._set_pts(typ="grad", samples=self.samples)

def get_lipschitz_cst(self, max_iter=10, **kwargs):
"""Return the Lipschitz constant of the operator.

Expand Down
53 changes: 28 additions & 25 deletions src/mrinufft/operators/interfaces/finufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,61 +26,48 @@ def __init__(
**kwargs,
):
self.shape = shape
self.samples = proper_trajectory(np.asfortranarray(samples), normalize="pi")
self.ndim = len(shape)
self.eps = float(eps)
self.n_trans = n_trans

self.n_samples = len(samples)
# the first element is dummy to index type 1 with 1
# and type 2 with 2.
self.plans = [None, None, None]
self.grad_plan = None

for i in [1, 2]:
self._make_plan(i, **kwargs)
self._set_pts(i)
self._make_plan(i, samples, **kwargs)
self._set_pts(i, samples)

def _make_plan(self, typ, **kwargs):
def _make_plan(self, typ, samples, **kwargs):
self.plans[typ] = Plan(
typ,
self.shape,
self.n_trans,
self.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
**kwargs,
)

def _make_plan_grad(self, **kwargs):
self.grad_plan = Plan(
2,
self.shape,
self.n_trans,
self.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
isign=1,
dtype=DTYPE_R2C[str(samples.dtype)],
**kwargs,
)
self._set_pts(typ="grad")

def _set_pts(self, typ):
def _set_pts(self, typ, samples):
fpts_axes = [None, None, None]
for i in range(self.ndim):
fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype)
fpts_axes[i] = np.array(samples[:, i], dtype=samples.dtype)
plan = self.grad_plan if typ == "grad" else self.plans[typ]
plan.setpts(*fpts_axes)

def adj_op(self, coeffs_data, grid_data):
"""Type 1 transform. Non Uniform to Uniform."""
if self.n_trans == 1:
grid_data = grid_data.reshape(self.shape)
coeffs_data = coeffs_data.reshape(len(self.samples))
coeffs_data = coeffs_data.reshape(self.n_samples)
return self.plans[1].execute(coeffs_data, grid_data)

def op(self, coeffs_data, grid_data):
"""Type 2 transform. Uniform to non-uniform."""
if self.n_trans == 1:
grid_data = grid_data.reshape(self.shape)
coeffs_data = coeffs_data.reshape(len(self.samples))
coeffs_data = coeffs_data.reshape(self.n_samples)
return self.plans[2].execute(grid_data, coeffs_data)

def toggle_grad_traj(self):
Expand Down Expand Up @@ -133,6 +120,7 @@ def __init__(
squeeze_dims=True,
**kwargs,
):
samples = proper_trajectory(np.asfortranarray(samples), normalize="pi")
self.raw_op = RawFinufftPlan(
samples,
shape,
Expand All @@ -152,10 +140,25 @@ def __init__(
)

@FourierOperatorBase.samples.setter
def samples(self, samples):
def samples(self, new_samples):
"""Update the plans when changing the samples."""
self._samples = samples
self._samples = proper_trajectory(
np.asfortranarray(new_samples), normalize="pi"
)
for typ in [1, 2, "grad"]:
if typ == "grad" and not self._grad_wrt_traj:
continue
self.raw_op._set_pts(typ)
self.raw_op._set_pts(typ, new_samples)
self.compute_density(self.density_method)

def _make_plan_grad(self, **kwargs):
self.raw_op.grad_plan = Plan(
2,
self.raw_op.shape,
self.raw_op.n_trans,
self.raw_op.eps,
dtype=DTYPE_R2C[str(self.samples.dtype)],
isign=1,
**kwargs,
)
self.raw_op._set_pts(typ="grad", samples=self.samples)
8 changes: 6 additions & 2 deletions src/mrinufft/operators/interfaces/gpunufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,12 +511,16 @@ def samples(self, samples):
samples: np.ndarray
The samples for the Fourier Operator.
"""
self._samples = proper_trajectory(
samples.astype(np.float32, copy=False), normalize="unit"
)
# TODO: gpuNUFFT needs to sort the points twice in this case.
# It could help to have access to directly dorted arrays from gpuNUFFT.
self.compute_density(self.density_method)
self.raw_op.set_pts(
samples,
self._samples,
density=self.density,
)
self._samples = samples

@classmethod
def pipe(
Expand Down
Loading
Loading