Autodiff wrt to trajectory #116

Merged: 55 commits, Jun 6, 2024

Changes from all commits (55 commits)
0138abd
update for the
May 24, 2024
f2bbe52
change the position of the parameters in autodiff
May 24, 2024
2a4c2e1
change the positionof the functions and the comments
May 24, 2024
638f55c
update for checkingthe style
May 24, 2024
11d770f
MINOR Extra changes, Please remove this changes
May 27, 2024
71579bc
update for forward part
May 29, 2024
a19f242
update for forward
alineyyy May 29, 2024
016c870
forward
alineyyy May 29, 2024
cd70dbf
update for forward
alineyyy May 29, 2024
1d7f30e
update for adjoint
alineyyy May 30, 2024
d53a7ca
Fixed working codes
chaithyagr May 30, 2024
df5006f
Remove vscode stuff
chaithyagr May 30, 2024
5c29929
Ruff and black
chaithyagr May 30, 2024
b7b7255
Merge branch 'master' into autodiff_ktraj
chaithyagr May 30, 2024
156f322
Merging
chaithyagr May 30, 2024
6b5ef8e
Remove torch dependence
chaithyagr May 30, 2024
f6a6e7a
Remove bad usage of torch
chaithyagr May 30, 2024
0dca414
Fix get_op
chaithyagr May 30, 2024
fa5a603
Added squeeze dims check right
chaithyagr May 30, 2024
e1264d8
Fixes
chaithyagr May 31, 2024
3aa6574
update for finufft
alineyyy May 31, 2024
c086897
Add support to gpuNUFFT
chaithyagr May 31, 2024
38e869d
Merge branch 'autodiff_ktraj' of github.com:mind-inria/mri-nufft into…
chaithyagr May 31, 2024
f21fb29
Add su-port for gpunufft
chaithyagr May 31, 2024
713ac66
Delete .vscode directory
alineyyy May 31, 2024
a0bc4ea
Merge branch 'master' into autodiff_ktraj
alineyyy May 31, 2024
abd9ffa
fix test_bindings
alineyyy Jun 2, 2024
f41b8d6
Merge remote-tracking branch 'refs/remotes/origin/autodiff_ktraj' int…
alineyyy Jun 2, 2024
44d8579
fix gpunufft pipe
alineyyy Jun 2, 2024
af33d5c
fix test-cpu
alineyyy Jun 2, 2024
b82eb90
change test-ci
alineyyy Jun 2, 2024
10f4b0b
black style check
alineyyy Jun 2, 2024
1fa060f
Delete .vscode directory
alineyyy Jun 2, 2024
8499e3b
Merge branch 'master' into autodiff_ktraj
chaithyagr Jun 3, 2024
7bcec2b
Moving the test_autodiff to operators, so that it is tested
chaithyagr Jun 3, 2024
b501a0c
update for comments
alineyyy Jun 3, 2024
e936fca
fix get_fourier_matrix
alineyyy Jun 3, 2024
c60c641
style check
alineyyy Jun 3, 2024
c52dc2f
style check
alineyyy Jun 3, 2024
c279683
update for PAC'S comments
alineyyy Jun 4, 2024
ebc1f56
fix autograd_available
alineyyy Jun 4, 2024
9f7c5d7
fix autograd_available
alineyyy Jun 4, 2024
cfa781b
fix style checking
alineyyy Jun 4, 2024
6d4e4c5
reduce the tolerance in test_autodiff
alineyyy Jun 4, 2024
b906980
set the proper tolerance in test_autodiff
alineyyy Jun 4, 2024
dbce464
update for comments
alineyyy Jun 6, 2024
6879643
Merge branch 'master' into autodiff_ktraj
chaithyagr Jun 6, 2024
aef0b8f
black style checking
alineyyy Jun 6, 2024
8e6d111
Merge remote-tracking branch 'refs/remotes/origin/autodiff_ktraj' int…
alineyyy Jun 6, 2024
5e591a7
black style checking
alineyyy Jun 6, 2024
7ff8061
fix get_fourier_matrix
alineyyy Jun 6, 2024
88c374d
Delete .vscode directory
alineyyy Jun 6, 2024
a5744b1
update .gitignore
alineyyy Jun 6, 2024
48422b9
Merge remote-tracking branch 'refs/remotes/origin/autodiff_ktraj' int…
alineyyy Jun 6, 2024
127899c
Merge branch 'master' into autodiff_ktraj
paquiteau Jun 6, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -6,6 +6,8 @@ dist/
 examples/*.ipynb
 *.xml
 .coverage*
+.vscode
+
 .idea
 *.log
4 changes: 3 additions & 1 deletion pyproject.toml
@@ -11,13 +11,15 @@ dynamic = ["version"]

 [project.optional-dependencies]

-gpunufft = ["gpuNUFFT>=0.7.5", "cupy-cuda11x"]
+gpunufft = ["gpuNUFFT>=0.8.0", "cupy-cuda11x"]
 cufinufft = ["cufinufft", "cupy-cuda11x"]
 finufft = ["finufft"]
 pynfft = ["pynfft2", "cython<3.0.0"]
 pynufft = ["pynufft"]
 io = ["pymapvbvd"]
 smaps = ["scikit-image"]
+autodiff = ["torch"]
+

 test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"]
 dev = ["black", "isort", "ruff"]
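A hedged usage note on the dependency change above: the new autodiff extra only pulls in torch. A minimal sketch (assuming the distribution is published as mri-nufft, e.g. pip install "mri-nufft[autodiff]") to check that the optional dependency is present:

# Sketch: confirm torch is importable after installing the 'autodiff' extra.
# The distribution name "mri-nufft" is assumed here.
import importlib.util

if importlib.util.find_spec("torch") is None:
    raise ImportError("torch is missing; install the 'autodiff' extra")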
21 changes: 15 additions & 6 deletions src/mrinufft/_utils.py
@@ -75,25 +75,34 @@ def proper_trajectory(trajectory, normalize="pi"):
         The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi
     """
     # flatten to a list of point
+    xp = get_array_module(trajectory)  # check if the trajectory is a tensor
     try:
-        new_traj = np.asarray(trajectory).copy()
+        new_traj = (
+            trajectory.clone()
+            if xp.__name__ == "torch"
+            else np.asarray(trajectory).copy()
+        )
     except Exception as e:
         raise ValueError(
             "trajectory should be array_like, with the last dimension being coordinates"
         ) from e

     new_traj = new_traj.reshape(-1, trajectory.shape[-1])

-    if normalize == "pi" and np.max(abs(new_traj)) - 1e-4 < 0.5:
+    max_abs_val = xp.max(xp.abs(new_traj))
+
+    if normalize == "pi" and max_abs_val - 1e-4 < 0.5:
         warnings.warn(
             "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)"
         )
-        new_traj *= 2 * np.pi
-    elif normalize == "unit" and np.max(abs(new_traj)) - 1e-4 > 0.5:
+        new_traj *= 2 * xp.pi
+    elif normalize == "unit" and max_abs_val - 1e-4 > 0.5:
         warnings.warn(
             "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)"
         )
-        new_traj /= 2 * np.pi
-    if normalize == "unit" and np.max(new_traj) >= 0.5:
+        new_traj *= 1 / (2 * xp.pi)
+
+    if normalize == "unit" and max_abs_val >= 0.5:
         new_traj = (new_traj + 0.5) % 1 - 0.5
     return new_traj
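To make the normalization logic above concrete, here is a small self-contained sketch of the NumPy path (the torch path behaves identically, dispatched through the array module returned by get_array_module):

import numpy as np

# A trajectory in [-0.5, 0.5), shape (Nc, Ns, dim), as proper_trajectory expects.
traj = np.random.uniform(-0.5, 0.5, size=(4, 256, 2)).astype(np.float32)

flat = traj.reshape(-1, traj.shape[-1])      # (Nc * Ns, dim)
max_abs_val = np.max(np.abs(flat))
if max_abs_val - 1e-4 < 0.5:                 # samples look like [-0.5, 0.5)
    flat = flat * (2 * np.pi)                # rescale to [-pi, pi)
assert np.all(np.abs(flat) <= np.pi)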
99 changes: 85 additions & 14 deletions src/mrinufft/operators/autodiff.py
@@ -1,40 +1,106 @@
 """Torch autodifferentiation for MRI-NUFFT."""

 import torch
+import numpy as np


 class _NUFFT_OP(torch.autograd.Function):
-    """Autograd support for op nufft function."""
+    """
+    Autograd support for op nufft function.
+
+    This class implements an efficient approximation of the Jacobian matrices.
+
+    References
+    ----------
+    Wang G, Fessler J A. "Efficient approximation of Jacobian matrices involving a
+    non-uniform fast Fourier transform (NUFFT)."
+    IEEE Transactions on Computational Imaging, 2023, 9: 43-54.
+    """

     @staticmethod
-    def forward(ctx, x, nufft_op):
+    def forward(ctx, x, traj, nufft_op):
         """Forward image -> k-space."""
-        ctx.save_for_backward(x)
+        ctx.save_for_backward(x, traj)
         ctx.nufft_op = nufft_op
         return nufft_op.op(x)

     @staticmethod
     def backward(ctx, dy):
         """Backward image -> k-space."""
-        (x,) = ctx.saved_tensors
-        return ctx.nufft_op.adj_op(dy), None
+        (x, traj) = ctx.saved_tensors
+        grad_data = None
+        grad_traj = None
+        if ctx.nufft_op._grad_wrt_data:
+            grad_data = ctx.nufft_op.adj_op(dy)
+        if ctx.nufft_op._grad_wrt_traj:
+            im_size = x.size()[1:]
+            factor = 1
+            if ctx.nufft_op.backend == "gpunufft":
+                factor *= np.pi * 2
+            r = [
+                torch.linspace(-size / 2, size / 2 - 1, size) * factor
+                for size in im_size
+            ]
+            grid_r = torch.meshgrid(*r, indexing="ij")
+            grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...]
+
+            grid_x = x * grid_r  # Element-wise multiplication: x * r
+            nufft_dx_dom = torch.cat(
+                [ctx.nufft_op.op(grid_x[:, i, :, :]) for i in range(grid_x.size(1))],
+                dim=1,
+            )
+
+            grad_traj = torch.transpose(
+                (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1
+            ).type_as(traj)
+
+        return grad_data, grad_traj, None


 class _NUFFT_ADJOP(torch.autograd.Function):
     """Autograd support for adj_op nufft function."""

     @staticmethod
-    def forward(ctx, y, nufft_op):
+    def forward(ctx, y, traj, nufft_op):
         """Forward kspace -> image."""
-        ctx.save_for_backward(y)
+        ctx.save_for_backward(y, traj)
         ctx.nufft_op = nufft_op
         return nufft_op.adj_op(y)

     @staticmethod
     def backward(ctx, dx):
         """Backward kspace -> image."""
-        (y,) = ctx.saved_tensors
-        return ctx.nufft_op.op(dx), None
+        (y, traj) = ctx.saved_tensors
+        grad_data = None
+        grad_traj = None
+        if ctx.nufft_op._grad_wrt_data:
+            grad_data = ctx.nufft_op.op(dx)
+        if ctx.nufft_op._grad_wrt_traj:
+            ctx.nufft_op.raw_op.toggle_grad_traj()
+            im_size = dx.size()[2:]
+            factor = 1
+            if ctx.nufft_op.backend == "gpunufft":
+                factor *= np.pi * 2
+            r = [
+                torch.linspace(-size / 2, size / 2 - 1, size) * factor
+                for size in im_size
+            ]
+            grid_r = torch.meshgrid(*r, indexing="ij")
+            grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...]
+
+            grid_dx = torch.conj(dx) * grid_r
+            inufft_dx_dom = torch.cat(
+                [ctx.nufft_op.op(grid_dx[:, i, :, :]) for i in range(grid_dx.size(1))],
+                dim=1,
+            )
+
+            grad_traj = torch.transpose(
+                (1j * y * inufft_dx_dom).squeeze(), 0, 1
+            ).type_as(traj)
+
+            ctx.nufft_op.raw_op.toggle_grad_traj()
+
+        return grad_data, grad_traj, None


 class MRINufftAutoGrad(torch.nn.Module):
@@ -46,19 +112,24 @@ class MRINufftAutoGrad(torch.nn.Module):
     nufft_op: Classic Non differentiable MRI-NUFFT operator.
     """

-    def __init__(self, nufft_op):
+    def __init__(self, nufft_op, wrt_data=True, wrt_traj=False):
         super().__init__()
-        if nufft_op.squeeze_dims:
-            raise ValueError("Squeezing dimensions is not " "supported for autodiff.")
+        if (wrt_data or wrt_traj) and nufft_op.squeeze_dims:
+            raise ValueError("Squeezing dimensions is not supported for autodiff.")

         self.nufft_op = nufft_op
+        self.nufft_op._grad_wrt_traj = wrt_traj
+        if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]:
+            self.nufft_op.raw_op._make_plan_grad()
+        self.nufft_op._grad_wrt_data = wrt_data

     def op(self, x):
         r"""Compute the forward image -> k-space."""
-        return _NUFFT_OP.apply(x, self.nufft_op)
+        return _NUFFT_OP.apply(x, self.samples, self.nufft_op)

     def adj_op(self, kspace):
         r"""Compute the adjoint k-space -> image."""
-        return _NUFFT_ADJOP.apply(kspace, self.nufft_op)
+        return _NUFFT_ADJOP.apply(kspace, self.samples, self.nufft_op)

     def __getattr__(self, name):
         """Get the attribute from the root operator."""
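As an end-to-end usage sketch of the new trajectory gradients (hedged: the backend, shapes, and exact call pattern here are illustrative, not taken from the merged tests):

import numpy as np
import torch
from mrinufft import get_operator

kspace_loc = np.random.uniform(-0.5, 0.5, (2048, 2)).astype(np.float32)
# Without constructor args, get_operator returns a partial; the MRINufftAutoGrad
# wrapper is attached when the operator is built (see base.py below).
nufft = get_operator("finufft", wrt_data=True, wrt_traj=True)(
    samples=kspace_loc, shape=(64, 64), squeeze_dims=False
)

image = torch.randn(1, 1, 64, 64, dtype=torch.complex64, requires_grad=True)
ksp = nufft.op(image)              # routed through _NUFFT_OP.apply
loss = torch.abs(ksp).pow(2).sum() # real-valued scalar loss
loss.backward()                    # backward() fills grad_data and grad_traj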
46 changes: 27 additions & 19 deletions src/mrinufft/operators/base.py
@@ -63,18 +63,19 @@ def list_backends(available_only=False):
     ]


-def get_operator(backend_name: str, *args, autograd=None, **kwargs):
+def get_operator(
+    backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs
+):
     """Return an MRI Fourier operator interface using the correct backend.

     Parameters
     ----------
     backend_name: str
         Backend name
-
-    autograd: str, default None
-        if set to "data" will provide an operator with autodiff capabilities with
-        respect to it.
-
+    wrt_data: bool, default False
+        If set, gradients with respect to the data (images) will be available.
+    wrt_traj: bool, default False
+        If set, gradients with respect to the trajectory will be available.
     *args, **kwargs:
         Arguments to pass to the operator constructor.

@@ -107,11 +108,11 @@ class or instance of class if args or kwargs are given.
     if args or kwargs:
         operator = operator(*args, **kwargs)

-    if autograd:
-        if isinstance(operator, FourierOperatorBase):
-            operator = operator.make_autograd(variable=autograd)
-        else:  # partial
-            operator = partial(operator.with_autograd, variable=autograd)
+    # if autograd:
+    if isinstance(operator, FourierOperatorBase):
+        operator = operator.make_autograd(wrt_data, wrt_traj)
+    elif wrt_data or wrt_traj:  # instance will be created later
+        operator = partial(operator.with_autograd, wrt_data, wrt_traj)
     return operator


@@ -195,6 +196,7 @@ class FourierOperatorBase(ABC):
     """

     interfaces: dict[str, tuple] = {}
+    autograd_available = False

     def __init__(self):
         if not self.available:

@@ -294,14 +296,20 @@ def compute_smaps(self, method=None):
             **kwargs,
         )

-    def make_autograd(self, variable="data"):
+    def make_autograd(self, wrt_data=True, wrt_traj=False):
         """Make a new Operator with autodiff support.

         Parameters
         ----------
-        variable: str, default data
-            variable on which the gradient is computed with respect to.
-
+        wrt_data : bool, optional
+            If True, compute the gradient with respect to the data (default: True).
+
+        wrt_traj : bool, optional
+            If True, compute the gradient with respect to the trajectory
+            (default: False).

         Returns
         -------
         torch.nn.module

@@ -314,10 +322,10 @@ def make_autograd(self, variable="data"):
         """
         if not AUTOGRAD_AVAILABLE:
             raise ValueError("Autograd not available, ensure torch is installed.")
-        if variable == "data":
-            return MRINufftAutoGrad(self)
-        else:
-            raise ValueError(f"Autodiff with respect to {variable} is not supported.")
+        if not self.autograd_available:
+            raise ValueError("Backend does not support auto-differentiation.")
+
+        return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)

     def compute_density(self, method=None):
         """Compute the density compensation weights and set it.

@@ -492,9 +500,9 @@ def __repr__(self):
         )

     @classmethod
-    def with_autograd(cls, variable, *args, **kwargs):
+    def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs):
         """Return a Fourier operator with autograd capabilities."""
-        return cls(*args, **kwargs).make_autograd(variable)
+        return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj)


 class FourierOperatorCPU(FourierOperatorBase):
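To summarize the interface change in this file (a sketch; only names introduced in this diff are used):

from mrinufft import get_operator

# Before this PR, autograd was requested with a string flag:
#   op = get_operator("cufinufft", autograd="data")
# After it, each differentiable variable gets its own boolean, so gradients
# wrt data and trajectory can be requested together, which the single
# `variable` string could not express.
OpClass = get_operator("cufinufft", wrt_data=True, wrt_traj=True)
# Without args/kwargs this is a partial; the MRINufftAutoGrad wrapper is
# created once the operator is instantiated with samples and shape.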
29 changes: 27 additions & 2 deletions src/mrinufft/operators/interfaces/cufinufft.py
@@ -66,6 +66,7 @@ def __init__(
         # the first element is dummy to index type 1 with 1
         # and type 2 with 2.
         self.plans = [None, None, None]
+        self.grad_plan = None

         for i in [1, 2]:
             self._make_plan(i, **kwargs)

@@ -89,8 +90,21 @@ def _make_plan(self, typ, **kwargs):
             **kwargs,
         )

+    def _make_plan_grad(self, **kwargs):
+        self.grad_plan = Plan(
+            2,
+            self.shape,
+            self.n_trans,
+            self.eps,
+            dtype=DTYPE_R2C[str(self.samples.dtype)],
+            isign=1,
+            **kwargs,
+        )
+        self._set_pts(typ="grad")
+
     def _set_pts(self, typ):
-        self.plans[typ].setpts(
+        plan = self.grad_plan if typ == "grad" else self.plans[typ]
+        plan.setpts(
             cp.array(self.samples[:, 0], copy=False),
             cp.array(self.samples[:, 1], copy=False),
             cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None,

@@ -102,6 +116,12 @@ def _destroy_plan(self, typ):
             del p
             self.plans[typ] = None

+    def _destroy_plan_grad(self):
+        if self.grad_plan is not None:
+            p = self.grad_plan
+            del p
+            self.grad_plan = None
+
     def type1(self, coeff_data, grid_data):
         """Type 1 transform. Non Uniform to Uniform."""
         return self.plans[1].execute(coeff_data, grid_data)

@@ -110,6 +130,10 @@ def type2(self, grid_data, coeff_data):
         """Type 2 transform. Uniform to non-uniform."""
         return self.plans[2].execute(grid_data, coeff_data)

+    def toggle_grad_traj(self):
+        """Toggle between the gradient plan and the plan for the type 2 transform."""
+        self.plans[2], self.grad_plan = self.grad_plan, self.plans[2]
+

 class MRICufiNUFFT(FourierOperatorBase):
     """MRI Transform operator, build around cufinufft.

@@ -165,6 +189,7 @@ class MRICufiNUFFT(FourierOperatorBase):
     backend = "cufinufft"
     available = CUFINUFFT_AVAILABLE and CUPY_AVAILABLE
+    autograd_available = True

     def __init__(
         self,

@@ -195,12 +220,12 @@ def __init__(
         self.n_trans = n_trans
         self.squeeze_dims = squeeze_dims
         self.n_coils = n_coils
-        self.autograd_available = True
         # For now only single precision is supported
         self.samples = np.asfortranarray(
             proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False)
         )
         self.dtype = self.samples.dtype
+
         # density compensation support
         if is_cuda_array(density):
             self.density = density
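The toggle_grad_traj mechanism deserves a short note: the gradient plan is a second type 2 plan built with isign=1 (the forward type 2 plan defaults to the opposite sign), and backward() swaps the two in place around the gradient computation, then swaps them back. A standalone sketch of that swap pattern (names here are illustrative, not the library's):

# Illustrative sketch of the plan-swap pattern behind toggle_grad_traj.
class PlanHolder:
    def __init__(self):
        # plans[0] is a dummy so types 1 and 2 index directly, as in the
        # cufinufft interface above.
        self.plans = [None, "type1_plan", "type2_plan_isign=-1"]
        self.grad_plan = "type2_plan_isign=+1"

    def toggle_grad_traj(self):
        self.plans[2], self.grad_plan = self.grad_plan, self.plans[2]

holder = PlanHolder()
holder.toggle_grad_traj()   # the gradient plan now answers type2() calls
assert holder.plans[2] == "type2_plan_isign=+1"
holder.toggle_grad_traj()   # restore the forward plan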