Autodiff wrt to trajectory #116
Conversation
@matthieutrs if you want to have a look ;)
Thank you for the great work! I'm not yet done with the review, but here are a couple of issues to fix.
A couple of things to be discussed generally, @paquiteau.
@paquiteau I handled the minor details; locally my tests are passing. I'll add you to officially review it in your free time. @alineyyy will handle your review comments. I shall try to see if I can support this in gpuNUFFT in the meantime. Great job everyone!
Ahhh I forgot, @alineyyy you need to do the same updates to the finufft interface...
LGTM. Adding @paquiteau for some final checks, and then we can merge when green.
This is a major milestone for MRI-NUFFT, and we have never been so close to merging. Thanks @alineyyy for making this possible!
Some comments/questions to address, and a few suggestions for making things even better.
src/mrinufft/operators/autodiff.py (Outdated)

```diff
-(x,) = ctx.saved_tensors
-return ctx.nufft_op.adj_op(dy), None
+(x, traj) = ctx.saved_tensors
```
```suggestion
grad_data = None
grad_traj = None
```
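The pattern the suggestion above points at can be sketched with a toy `torch.autograd.Function`: initialize every gradient to `None` up front and fill in only the ones torch actually asks for. Names here are illustrative, not the actual mri-nufft autodiff class.

```python
import torch

class Scale(torch.autograd.Function):
    """Toy custom op: y = x * factor, with gradients w.r.t. both inputs."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.save_for_backward(x, factor)
        return x * factor

    @staticmethod
    def backward(ctx, dy):
        x, factor = ctx.saved_tensors
        grad_x = None        # one None per forward() input, set up front
        grad_factor = None
        if ctx.needs_input_grad[0]:
            grad_x = dy * factor
        if ctx.needs_input_grad[1]:
            grad_factor = (dy * x).sum()
        return grad_x, grad_factor

x = torch.ones(3, requires_grad=True)
factor = torch.tensor(2.0, requires_grad=True)
Scale.apply(x, factor).sum().backward()
```

Because both gradients start as `None`, the trailing `else: grad_traj = None` branch discussed below becomes unnecessary.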
src/mrinufft/operators/autodiff.py (Outdated)

```python
else:
    grad_traj = None
```
```suggestion
else:
    grad_traj = None
```
```diff
 from ..base import FourierOperatorCPU
-from mrinufft._utils import proper_trajectory
+from mrinufft._utils import proper_trajectory, get_array_module


 def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False):
     """Get the NDFT Fourier Matrix."""
```
Some documentation as well, if possible?
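For reference, a hedged sketch of what `get_fourier_matrix` computes, reduced to a 1-D trajectory; the real mri-nufft implementation handles N-D shapes, dtypes, and optional normalization, so treat this as illustrative only.

```python
import numpy as np

def ndft_matrix_1d(ktraj, n):
    """Explicit NDFT matrix: F[k, m] = exp(-2j*pi * ktraj[k] * grid[m])."""
    grid = np.arange(n) - n // 2              # centered image-space grid
    return np.exp(-2j * np.pi * np.outer(ktraj, grid))

# On a uniform trajectory the NDFT matrix is a (phase-shifted) DFT
# matrix, so F @ F^H = n * I -- a handy sanity check.
ktraj = (np.arange(8) - 4) / 8
F = ndft_matrix_1d(ktraj, 8)
```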
src/mrinufft/operators/base.py (Outdated)

```python
if not AUTOGRAD_AVAILABLE or not self.autograd_available:
    raise ValueError("Autograd not available, ensure torch is installed.")
```
```suggestion
if not AUTOGRAD_AVAILABLE:
    raise ValueError("Autograd not available, ensure torch is installed.")
if not self.autograd_available:
    raise ValueError("Backend does not support auto-differentiation.")
```
```python
self.shape,
self.n_trans,
self.eps,
dtype="complex64" if self.samples.dtype == "float32" else "complex128",
```
You can use the DTYPE_R2C dict for that.
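A sketch of the dict-based lookup being suggested; `DTYPE_R2C` is assumed here to map real dtype names to their complex counterparts (the actual dict lives inside mri-nufft and may be keyed differently).

```python
import numpy as np

# Assumed mapping from real sample dtypes to complex operator dtypes.
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}

samples = np.zeros(10, dtype=np.float32)
# Replaces the inline ternary on the sample dtype:
cdtype = DTYPE_R2C[str(samples.dtype)]
```

A dict lookup also fails loudly (KeyError) on an unexpected dtype, instead of silently falling through to `complex128` as the ternary does.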
```python
matrix = (
    xp.exp(-2j * xp.pi * traj_grid).to(dtype).to(device).clone()
    if xp.__name__ == "torch"
    else (xp.exp(-2j * xp.pi * traj_grid, dtype=dtype))
)
```
```suggestion
matrix = xp.exp(-2j * xp.pi * traj_grid)
if xp.__name__ == "torch":
    matrix.to(dtype=dtype, device=device, copy=True)
```
I think I have to use `matrix = matrix.to(...)` here; otherwise it will cause an error in PyTorch's autograd function.
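Why the reassignment matters: conversion methods like torch's `Tensor.to` return a new tensor rather than modifying in place. NumPy's `astype` behaves the same way and is used below to keep the sketch dependency-free.

```python
import numpy as np

m = np.zeros(3, dtype=np.float32)
m.astype(np.complex64)        # result discarded; m is unchanged
assert m.dtype == np.float32

m = m.astype(np.complex64)    # reassign to keep the converted array
assert m.dtype == np.complex64
```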
Almost there; there are still a few comments to address.
If the failing gpunufft test is too strict and causing problems, just ensure the gradients are close, and maybe make the test less strict.
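A hedged sketch of a looser gradient check along those lines; the helper name and tolerances are illustrative, not the actual mri-nufft test code.

```python
import numpy as np

def assert_grads_close(grad_ref, grad_test, rtol=1e-2, atol=1e-3):
    """Assert two gradient arrays agree within loose tolerances."""
    np.testing.assert_allclose(grad_test, grad_ref, rtol=rtol, atol=atol)

# Small numerical differences between backends pass...
assert_grads_close(np.array([1.0, 2.0]), np.array([1.001, 2.0005]))
```

Loosening `rtol`/`atol` keeps the test meaningful (gradients must still agree to ~1%) while tolerating backend-specific floating-point noise.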
This resolves #103