From 0138abdce49ea84ca4fabf4647e401b02677ff0e Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Fri, 24 May 2024 15:30:04 +0200 Subject: [PATCH 01/45] update for the --- src/mrinufft/operators/autodiff.py | 45 +++++++++++++--- tests/test_autodiff.py | 83 ++++++++++++++++++++++++++---- 2 files changed, 113 insertions(+), 15 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 155fb485..c6b674c7 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -3,21 +3,54 @@ import torch +# class _NUFFT_OP(torch.autograd.Function): +# """Autograd support for op nufft function.""" + +# @staticmethod +# def forward(ctx, x, nufft_op): +# """Forward image -> k-space.""" +# ctx.save_for_backward(x) +# ctx.nufft_op = nufft_op +# return nufft_op.op(x) + +# @staticmethod +# def backward(ctx, dy): +# """Backward image -> k-space.""" +# (x,) = ctx.saved_tensors + +# return ctx.nufft_op.adj_op(dy), None + class _NUFFT_OP(torch.autograd.Function): """Autograd support for op nufft function.""" @staticmethod - def forward(ctx, x, nufft_op): + def forward(ctx, x, nufft_op, traj): """Forward image -> k-space.""" - ctx.save_for_backward(x) + ctx.save_for_backward(x, traj) ctx.nufft_op = nufft_op return nufft_op.op(x) + @staticmethod def backward(ctx, dy): """Backward image -> k-space.""" - (x,) = ctx.saved_tensors - return ctx.nufft_op.adj_op(dy), None + print(dy.shape) + (x,traj) = ctx.saved_tensors + + im_size = x.size()[1:] #[16, 16] + r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] #len(r) = 2 / [16] for each + grid_r = torch.meshgrid(*r, indexing='ij') + grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...]# add batch size [1, 2, 16, 16] + + grid_x = x * grid_r + nufft_dx_dom = torch.cat([ctx.nufft_op.op(grid_x[:, i:i+1, :, :]) for i in range(grid_x.size(1))], dim=1) + #nufft_dx_dom = ctx.nufft_op.op(x * grid_r) # not work beacuse op only accpect [1, 1, *im_size] + + grad_traj = torch.transpose((-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1).type_as(traj) #dy should be [1, 1, 256] nufft_dx_dom should be [1, 2, 256] the first dim is batch size which should be reserved for the nufft + + + + return ctx.nufft_op.adj_op(dy), None, grad_traj class _NUFFT_ADJOP(torch.autograd.Function): @@ -52,9 +85,9 @@ def __init__(self, nufft_op): raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op - def op(self, x): + def op(self, x, traj): r"""Compute the forward image -> k-space.""" - return _NUFFT_OP.apply(x, self.nufft_op) + return _NUFFT_OP.apply(x, self.nufft_op, traj) def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 3b24a25d..99ff2418 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -6,7 +6,7 @@ from pytest_cases import parametrize_with_cases, parametrize, fixture from case_trajectories import CasesTrajectories from mrinufft.operators import get_operator - +import warnings from helpers import ( kspace_from_op, @@ -46,10 +46,61 @@ def operator(kspace_loc, shape, backend, autograd): return nufft +def proper_trajectory_torch(trajectory, normalize="pi"): + + if not torch.is_tensor(trajectory): + raise ValueError("trajectory should be a torch.Tensor") + + new_traj = trajectory.clone() + new_traj = new_traj.view(-1, trajectory.shape[-1]) + + if normalize == "pi" and torch.max(torch.abs(new_traj)) - 1e-4 < 0.5: + warnings.warn( + "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)" + ) + new_traj *= 2 * torch.pi + elif normalize == "unit" and torch.max(torch.abs(new_traj)) - 1e-4 > 0.5: + warnings.warn( + "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)" + ) + new_traj /= 2 * torch.pi + + if normalize == "unit" and torch.max(new_traj) >= 0.5: + new_traj = (new_traj + 0.5) % 1 - 0.5 + + return new_traj + +# We need the calculation of the NDFT matrix to be done in the GPU/CPU, so I rewrite this function to calculate it on GPU/CPU (it's for grad_traj backpropagation, the +# whole things about trajectory should be done with required_grad=True which needs everthing to be put on the same device) +def get_fourier_matrix_torch(ktraj, shape, dtype=torch.complex64, normalize=False): + """Get the NDFT Fourier Matrix.""" + device = ktraj.device + ktraj = proper_trajectory_torch(ktraj, normalize="unit") + n = np.prod(shape) + ndim = len(shape) + + r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] + + grid_r = torch.meshgrid(r, indexing="ij") + grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) + + traj_grid = torch.matmul(ktraj, grid_r) + matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() + + if normalize: + matrix /= torch.sqrt(torch.tensor(np.prod(shape), device=device)) * torch.pow(torch.sqrt(torch.tensor(2, device=device)), ndim) + + return matrix + +def ndft_matrix_ktraj(operator, k_traj): + """Get the NDFT matrix from the operator.""" + return get_fourier_matrix_torch(k_traj, operator.shape, normalize=True) # operator.samples is trajectory + + @fixture(scope="module") def ndft_matrix(operator): """Get the NDFT matrix from the operator.""" - return get_fourier_matrix(operator.samples, operator.shape, normalize=True) + return get_fourier_matrix(operator.samples, operator.shape, normalize=True) @pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) @@ -81,24 +132,38 @@ def test_adjoint_and_grad(operator, ndft_matrix, interface): @pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_forward_and_grad(operator, ndft_matrix, interface): +def test_forward_and_grad(operator, interface): """Test the adjoint and gradient of the operator.""" if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") - ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) - ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) - img_data = to_interface(image_from_op(operator), interface=interface) + + ktraj = to_interface(np.copy(operator.samples), interface = interface) + + ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) #[1, 256] + img_data = to_interface(image_from_op(operator), interface=interface) #[1, 16, 16] 1 is num of coil + img_data.requires_grad = True + ktraj.requires_grad = True with torch.autograd.set_detect_anomaly(True): - ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) - ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) + ksp_data = operator.op(img_data, ktraj).reshape(ksp_data_ref.shape) #[1, 1, 256] => [1, 256] + ndft_matrix_torch = ndft_matrix_ktraj(operator,ktraj) + ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) + + # Check if nufft and ndft w.r.t trajectory are close in the backprop + gradient_ndft_ktraj = torch.autograd.grad(loss_ndft, ktraj, retain_graph=True)[0] + gradient_nufft_ktraj= torch.autograd.grad(loss_nufft, ktraj, retain_graph=True)[ #d dy = dL / ksp_data, dy should be [1, 256] + 0 + ] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=6e-3) + + # Check if nufft and ndft are close in the backprop - gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] + gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] gradient_nufft_kdata = torch.autograd.grad(loss_nufft, img_data, retain_graph=True)[ 0 ] From f2bbe52e5335505c8f8bbe4b518ac70f24654859 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Fri, 24 May 2024 15:41:46 +0200 Subject: [PATCH 02/45] change the position of the parameters in autodiff --- src/mrinufft/operators/autodiff.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index c6b674c7..61f9553f 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -24,7 +24,7 @@ class _NUFFT_OP(torch.autograd.Function): """Autograd support for op nufft function.""" @staticmethod - def forward(ctx, x, nufft_op, traj): + def forward(ctx, x, traj, nufft_op): #FIXME: change the position """Forward image -> k-space.""" ctx.save_for_backward(x, traj) ctx.nufft_op = nufft_op @@ -50,7 +50,7 @@ def backward(ctx, dy): - return ctx.nufft_op.adj_op(dy), None, grad_traj + return ctx.nufft_op.adj_op(dy), grad_traj, None class _NUFFT_ADJOP(torch.autograd.Function): @@ -87,7 +87,7 @@ def __init__(self, nufft_op): def op(self, x, traj): r"""Compute the forward image -> k-space.""" - return _NUFFT_OP.apply(x, self.nufft_op, traj) + return _NUFFT_OP.apply(x, traj, self.nufft_op ) def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" From 2a4c2e13561fec291caae1422a44e80c288da560 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Fri, 24 May 2024 17:55:03 +0200 Subject: [PATCH 03/45] change the positionof the functions and the comments --- src/mrinufft/operators/autodiff.py | 65 ++++++++++++++---------------- tests/test_autodiff.py | 49 +++++++++++----------- 2 files changed, 55 insertions(+), 59 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 61f9553f..e68a2dbd 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -3,54 +3,51 @@ import torch -# class _NUFFT_OP(torch.autograd.Function): -# """Autograd support for op nufft function.""" - -# @staticmethod -# def forward(ctx, x, nufft_op): -# """Forward image -> k-space.""" -# ctx.save_for_backward(x) -# ctx.nufft_op = nufft_op -# return nufft_op.op(x) - -# @staticmethod -# def backward(ctx, dy): -# """Backward image -> k-space.""" -# (x,) = ctx.saved_tensors - -# return ctx.nufft_op.adj_op(dy), None - class _NUFFT_OP(torch.autograd.Function): - """Autograd support for op nufft function.""" + """ + Autograd support for op nufft function. + + This class is implemented by an efficient approximation of Jacobian Matrices. + + References: + ----------- + Wang G, Fessler J A. "Efficient approximation of Jacobian matrices involving a non-uniform fast Fourier transform (NUFFT)." + IEEE Transactions on Computational Imaging, 2023, 9: 43-54. + """ @staticmethod - def forward(ctx, x, traj, nufft_op): #FIXME: change the position + def forward(ctx, x, traj, nufft_op): """Forward image -> k-space.""" ctx.save_for_backward(x, traj) ctx.nufft_op = nufft_op return nufft_op.op(x) - @staticmethod def backward(ctx, dy): """Backward image -> k-space.""" - print(dy.shape) - (x,traj) = ctx.saved_tensors + (x, traj) = ctx.saved_tensors - im_size = x.size()[1:] #[16, 16] - r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] #len(r) = 2 / [16] for each - grid_r = torch.meshgrid(*r, indexing='ij') - grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...]# add batch size [1, 2, 16, 16] + im_size = x.size()[1:] + r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + grid_r = torch.meshgrid(*r, indexing="ij") + grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...] - grid_x = x * grid_r - nufft_dx_dom = torch.cat([ctx.nufft_op.op(grid_x[:, i:i+1, :, :]) for i in range(grid_x.size(1))], dim=1) - #nufft_dx_dom = ctx.nufft_op.op(x * grid_r) # not work beacuse op only accpect [1, 1, *im_size] - - grad_traj = torch.transpose((-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1).type_as(traj) #dy should be [1, 1, 256] nufft_dx_dom should be [1, 2, 256] the first dim is batch size which should be reserved for the nufft - + grid_x = x * grid_r # Element-wise multiplication: x * r + nufft_dx_dom = torch.cat( + [ + ctx.nufft_op.op(grid_x[:, i : i + 1, :, :]) + for i in range(grid_x.size(1)) + ], + dim=1, + ) # Compute A(x * r) for each channel and concatenate along the channel dimension + grad_traj = torch.transpose( + (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 + ).type_as( + traj + ) # Compute gradient with respect to trajectory: -i * dy' * A(x * r) - return ctx.nufft_op.adj_op(dy), grad_traj, None + return ctx.nufft_op.adj_op(dy), grad_traj, None class _NUFFT_ADJOP(torch.autograd.Function): @@ -87,7 +84,7 @@ def __init__(self, nufft_op): def op(self, x, traj): r"""Compute the forward image -> k-space.""" - return _NUFFT_OP.apply(x, traj, self.nufft_op ) + return _NUFFT_OP.apply(x, traj, self.nufft_op) def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 99ff2418..a052c29f 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -23,7 +23,7 @@ @fixture(scope="module") -@parametrize(backend=["cufinufft", "finufft"]) +@parametrize(backend=["cufinufft", "finufft", "gpunufft"]) @parametrize(autograd=["data"]) @parametrize_with_cases( "kspace_loc, shape", @@ -47,7 +47,7 @@ def operator(kspace_loc, shape, backend, autograd): def proper_trajectory_torch(trajectory, normalize="pi"): - + """Normalize the trajectory to be used by NUFFT operators on device.""" if not torch.is_tensor(trajectory): raise ValueError("trajectory should be a torch.Tensor") @@ -70,37 +70,39 @@ def proper_trajectory_torch(trajectory, normalize="pi"): return new_traj -# We need the calculation of the NDFT matrix to be done in the GPU/CPU, so I rewrite this function to calculate it on GPU/CPU (it's for grad_traj backpropagation, the -# whole things about trajectory should be done with required_grad=True which needs everthing to be put on the same device) + def get_fourier_matrix_torch(ktraj, shape, dtype=torch.complex64, normalize=False): - """Get the NDFT Fourier Matrix.""" + """Get the NDFT Fourier Matrix which is calculated on device.""" device = ktraj.device ktraj = proper_trajectory_torch(ktraj, normalize="unit") n = np.prod(shape) ndim = len(shape) - + r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] - + grid_r = torch.meshgrid(r, indexing="ij") grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) - + traj_grid = torch.matmul(ktraj, grid_r) matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() - + if normalize: - matrix /= torch.sqrt(torch.tensor(np.prod(shape), device=device)) * torch.pow(torch.sqrt(torch.tensor(2, device=device)), ndim) - + matrix /= torch.sqrt(torch.tensor(np.prod(shape), device=device)) * torch.pow( + torch.sqrt(torch.tensor(2, device=device)), ndim + ) + return matrix + def ndft_matrix_ktraj(operator, k_traj): """Get the NDFT matrix from the operator.""" - return get_fourier_matrix_torch(k_traj, operator.shape, normalize=True) # operator.samples is trajectory + return get_fourier_matrix_torch(k_traj, operator.shape, normalize=True) @fixture(scope="module") def ndft_matrix(operator): """Get the NDFT matrix from the operator.""" - return get_fourier_matrix(operator.samples, operator.shape, normalize=True) + return get_fourier_matrix(operator.samples, operator.shape, normalize=True) @pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) @@ -137,33 +139,30 @@ def test_forward_and_grad(operator, interface): if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") + ktraj = to_interface(np.copy(operator.samples), interface=interface) - ktraj = to_interface(np.copy(operator.samples), interface = interface) - - ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) #[1, 256] - img_data = to_interface(image_from_op(operator), interface=interface) #[1, 16, 16] 1 is num of coil + ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) + img_data = to_interface(image_from_op(operator), interface=interface) img_data.requires_grad = True ktraj.requires_grad = True with torch.autograd.set_detect_anomaly(True): - ksp_data = operator.op(img_data, ktraj).reshape(ksp_data_ref.shape) #[1, 1, 256] => [1, 256] - ndft_matrix_torch = ndft_matrix_ktraj(operator,ktraj) - ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) + ksp_data = operator.op(img_data, ktraj).reshape(ksp_data_ref.shape) + ndft_matrix_torch = ndft_matrix_ktraj(operator, ktraj) + ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) - # Check if nufft and ndft w.r.t trajectory are close in the backprop - gradient_ndft_ktraj = torch.autograd.grad(loss_ndft, ktraj, retain_graph=True)[0] - gradient_nufft_ktraj= torch.autograd.grad(loss_nufft, ktraj, retain_graph=True)[ #d dy = dL / ksp_data, dy should be [1, 256] + gradient_ndft_ktraj = torch.autograd.grad(loss_ndft, ktraj, retain_graph=True)[0] + gradient_nufft_ktraj = torch.autograd.grad(loss_nufft, ktraj, retain_graph=True)[ # 0 ] assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=6e-3) - # Check if nufft and ndft are close in the backprop - gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] + gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] gradient_nufft_kdata = torch.autograd.grad(loss_nufft, img_data, retain_graph=True)[ 0 ] From 638f55cbc38f899ff7e28381e897afe9e113f43b Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Fri, 24 May 2024 18:07:30 +0200 Subject: [PATCH 04/45] update for checkingthe style --- src/mrinufft/operators/autodiff.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index e68a2dbd..910e6db2 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -9,9 +9,10 @@ class _NUFFT_OP(torch.autograd.Function): This class is implemented by an efficient approximation of Jacobian Matrices. - References: - ----------- - Wang G, Fessler J A. "Efficient approximation of Jacobian matrices involving a non-uniform fast Fourier transform (NUFFT)." + References + ---------- + Wang G, Fessler J A. "Efficient approximation of Jacobian matrices involving a + non-uniform fast Fourier transform (NUFFT)." IEEE Transactions on Computational Imaging, 2023, 9: 43-54. """ @@ -39,7 +40,7 @@ def backward(ctx, dy): for i in range(grid_x.size(1)) ], dim=1, - ) # Compute A(x * r) for each channel and concatenate along the channel dimension + ) # Compute A(x * r) for each channel and concatenate along this dimension grad_traj = torch.transpose( (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 From 11d770f76bc2a19aa03b40b1100b90a9611642ae Mon Sep 17 00:00:00 2001 From: chaithyagr Date: Mon, 27 May 2024 10:44:01 +0200 Subject: [PATCH 05/45] MINOR Extra changes, Please remove this changes --- src/mrinufft/_utils.py | 3 ++- src/mrinufft/operators/autodiff.py | 2 +- src/mrinufft/operators/interfaces/nudft_numpy.py | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index bcfa7540..5c0734be 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -57,7 +57,7 @@ def auto_cast(array, dtype: DTypeLike): return array.astype(dtype) -def proper_trajectory(trajectory, normalize="pi"): +def proper_trajectory(trajectory, normalize="pi"): """Normalize the trajectory to be used by NUFFT operators. Parameters @@ -75,6 +75,7 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point + #FIXME: check if trajectory is torch. get_array_module : torch. Basically output must be torch if input is torch and everything must be done in torch. try: new_traj = np.asarray(trajectory).copy() except Exception as e: diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 910e6db2..605745f1 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -85,7 +85,7 @@ def __init__(self, nufft_op): def op(self, x, traj): r"""Compute the forward image -> k-space.""" - return _NUFFT_OP.apply(x, traj, self.nufft_op) + return _NUFFT_OP.apply(x, self.nufft_op) def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 4dd4b585..5c6ab5ea 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -10,6 +10,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): + #FIXME: check if trajectory is torch. get_array_module : torch. Basically output must be torch if input is torch and everything must be done in torch. """Get the NDFT Fourier Matrix.""" ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) From 71579bc86081cf3dff1e5bad0850a14a392d2213 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Wed, 29 May 2024 18:16:32 +0200 Subject: [PATCH 06/45] update for forward part --- .vscode/launch.json | 35 ++++++ .vscode/settings.json | 6 + src/mrinufft/_utils.py | 26 +++- src/mrinufft/operators/autodiff.py | 41 +++++-- .../operators/interfaces/cufinufft.py | 5 +- .../operators/interfaces/nudft_numpy.py | 36 ++++-- tests/test_autodiff.py | 116 +++++++----------- 7 files changed, 172 insertions(+), 93 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..cb49e2bd --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,35 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + + + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "module": "pytest", + "args": [ + "/volatile/Caini/mri-nufft/tests/test_autodiff.py" + ], + "console": "integratedTerminal", + "python": "/volatile/Caini/Envs/projector/bin/python", + "env": { + "PYTHONPATH": "/mri-nufft" + }, + + }, + { + "name": "DTest", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "purpose": ["debug-test"], + "console": "integratedTerminal", + "justMyCode": false + }, + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..d23fb8e7 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "python.testing.pytestArgs": [ + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index bcfa7540..5ae52926 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -75,25 +75,39 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point + module = get_array_module(trajectory) try: - new_traj = np.asarray(trajectory).copy() + new_traj = ( + trajectory.clone() + if module.__name__ == "torch" + else np.asarray(trajectory).copy() + ) except Exception as e: raise ValueError( "trajectory should be array_like, with the last dimension being coordinates" ) from e + new_traj = new_traj.reshape(-1, trajectory.shape[-1]) - if normalize == "pi" and np.max(abs(new_traj)) - 1e-4 < 0.5: + max_abs_val = ( + torch.max(torch.abs(new_traj)) + if module.__name__ == "torch" + else np.max(np.abs(new_traj)) + ) + + if normalize == "pi" and max_abs_val - 1e-4 < 0.5: warnings.warn( "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)" ) - new_traj *= 2 * np.pi - elif normalize == "unit" and np.max(abs(new_traj)) - 1e-4 > 0.5: + new_traj *= 2 * torch.pi if module.__name__ == "torch" else 2 * np.pi + elif normalize == "unit" and max_abs_val - 1e-4 > 0.5: warnings.warn( "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)" ) - new_traj /= 2 * np.pi - if normalize == "unit" and np.max(new_traj) >= 0.5: + new_traj *= ( + 1 / (2 * torch.pi) if module.__name__ == "torch" else 1 / (2 * np.pi) + ) + if normalize == "unit" and max_abs_val >= 0.5: new_traj = (new_traj + 0.5) % 1 - 0.5 return new_traj diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 910e6db2..23a0c9c8 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -1,6 +1,7 @@ """Torch autodifferentiation for MRI-NUFFT.""" import torch +import numpy as np class _NUFFT_OP(torch.autograd.Function): @@ -19,7 +20,7 @@ class _NUFFT_OP(torch.autograd.Function): @staticmethod def forward(ctx, x, traj, nufft_op): """Forward image -> k-space.""" - ctx.save_for_backward(x, traj) + ctx.save_for_backward(x, traj) # nufft_op.samples => traj ctx.nufft_op = nufft_op return nufft_op.op(x) @@ -55,17 +56,41 @@ class _NUFFT_ADJOP(torch.autograd.Function): """Autograd support for adj_op nufft function.""" @staticmethod - def forward(ctx, y, nufft_op): + def forward(ctx, y, traj, nufft_op): """Forward kspace -> image.""" - ctx.save_for_backward(y) + ctx.save_for_backward(y, traj) ctx.nufft_op = nufft_op return nufft_op.adj_op(y) @staticmethod def backward(ctx, dx): """Backward kspace -> image.""" - (y,) = ctx.saved_tensors - return ctx.nufft_op.op(dx), None + (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] + grad_traj = None + # im_size = dx.size()[2:] + # r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + # grid_r = torch.meshgrid(*r, indexing="ij") + # grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] #[1, 2, 16, 16] + + # diag_y = torch.diag_embed(y) #[1, 256, 256] 想要to be [1, 256, 16, 16] + + # ifft_diag_y = torch.cat( + # [ + # ctx.nufft_op.adj_op(diag_y[:, i : i + 1, :]) + # for i in range(diag_y.size(1)) + # ], + # dim=1, + # ) # [1, 2048, 32, 32] + + # grad_traj = torch.cat( + # [ + # (dx * grid_r[:, i : i + 1, :, :] * ifft_diag_y).sum(dim=(2, 3)) + # for i in range(grid_r.size(1)) + # ] + # ).type_as(traj) + + # grad_traj = torch.transpose(grad_traj, 0, 1).type_as(traj) + return ctx.nufft_op.op(dx), grad_traj, None class MRINufftAutoGrad(torch.nn.Module): @@ -83,13 +108,13 @@ def __init__(self, nufft_op): raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op - def op(self, x, traj): + def op(self, x): r"""Compute the forward image -> k-space.""" - return _NUFFT_OP.apply(x, traj, self.nufft_op) + return _NUFFT_OP.apply(x, self.samples, self.nufft_op) def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" - return _NUFFT_ADJOP.apply(kspace, self.nufft_op) + return _NUFFT_ADJOP.apply(kspace, self.samples, self.nufft_op) def __getattr__(self, name): """Get the attribute from the root operator.""" diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 3d512127..ece1c253 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -1,5 +1,5 @@ """Provides Operator for MR Image processing on GPU.""" - +import torch import warnings import numpy as np from mrinufft.operators.base import FourierOperatorBase, with_numpy_cupy @@ -88,6 +88,7 @@ def _make_plan(self, typ, **kwargs): dtype=DTYPE_R2C[str(self.samples.dtype)], **kwargs, ) + def _set_pts(self, typ): self.plans[typ].setpts( @@ -178,6 +179,7 @@ def __init__( verbose=False, squeeze_dims=False, n_trans=1, + grad_traj=False, **kwargs, ): # run the availaility check here to get detailled output. @@ -200,7 +202,6 @@ def __init__( proper_trajectory(samples, normalize="pi").astype(np.float32) ) self.dtype = self.samples.dtype - # density compensation support if is_cuda_array(density): self.density = density diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 4dd4b585..b9263e8d 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -4,23 +4,43 @@ import numpy as np import scipy as sp - +import torch from ..base import FourierOperatorCPU -from mrinufft._utils import proper_trajectory +from mrinufft._utils import proper_trajectory, get_array_module def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): """Get the NDFT Fourier Matrix.""" + module = get_array_module(ktraj) ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) ndim = len(shape) - matrix = np.zeros((len(ktraj), n), dtype=dtype) - r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] - grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) - traj_grid = ktraj @ grid_r - matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) + + if module.__name__ == "torch": + device = ktraj.device + dtype = torch.complex64 + r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] + grid_r = torch.meshgrid(r, indexing="ij") + grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) + traj_grid = torch.matmul(ktraj, grid_r) + matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() + + else: + r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] + grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) + traj_grid = ktraj @ grid_r + matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) + if normalize: - matrix /= np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape)) + matrix /= ( + ( + torch.sqrt(torch.tensor(np.prod(shape), device=device)) + * torch.pow(torch.sqrt(torch.tensor(2, device=device)), ndim) + ) + if module.__name__ == "torch" + else (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) + ) + return matrix diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index a052c29f..8f3cec4a 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -7,6 +7,8 @@ from case_trajectories import CasesTrajectories from mrinufft.operators import get_operator import warnings +from mrinufft._utils import proper_trajectory +import matplotlib.pyplot as plt from helpers import ( kspace_from_op, @@ -35,7 +37,6 @@ def operator(kspace_loc, shape, backend, autograd): """Create NUFFT operator with autodiff capabilities.""" kspace_loc = kspace_loc.astype(np.float32) - nufft = get_operator(backend_name=backend, autograd=autograd)( samples=kspace_loc, shape=shape, @@ -46,60 +47,7 @@ def operator(kspace_loc, shape, backend, autograd): return nufft -def proper_trajectory_torch(trajectory, normalize="pi"): - """Normalize the trajectory to be used by NUFFT operators on device.""" - if not torch.is_tensor(trajectory): - raise ValueError("trajectory should be a torch.Tensor") - - new_traj = trajectory.clone() - new_traj = new_traj.view(-1, trajectory.shape[-1]) - - if normalize == "pi" and torch.max(torch.abs(new_traj)) - 1e-4 < 0.5: - warnings.warn( - "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)" - ) - new_traj *= 2 * torch.pi - elif normalize == "unit" and torch.max(torch.abs(new_traj)) - 1e-4 > 0.5: - warnings.warn( - "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)" - ) - new_traj /= 2 * torch.pi - - if normalize == "unit" and torch.max(new_traj) >= 0.5: - new_traj = (new_traj + 0.5) % 1 - 0.5 - - return new_traj - - -def get_fourier_matrix_torch(ktraj, shape, dtype=torch.complex64, normalize=False): - """Get the NDFT Fourier Matrix which is calculated on device.""" - device = ktraj.device - ktraj = proper_trajectory_torch(ktraj, normalize="unit") - n = np.prod(shape) - ndim = len(shape) - - r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] - - grid_r = torch.meshgrid(r, indexing="ij") - grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) - - traj_grid = torch.matmul(ktraj, grid_r) - matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() - - if normalize: - matrix /= torch.sqrt(torch.tensor(np.prod(shape), device=device)) * torch.pow( - torch.sqrt(torch.tensor(2, device=device)), ndim - ) - - return matrix - - -def ndft_matrix_ktraj(operator, k_traj): - """Get the NDFT matrix from the operator.""" - return get_fourier_matrix_torch(k_traj, operator.shape, normalize=True) - - -@fixture(scope="module") +# @fixture(scope="module") def ndft_matrix(operator): """Get the NDFT matrix from the operator.""" return get_fourier_matrix(operator.samples, operator.shape, normalize=True) @@ -107,23 +55,44 @@ def ndft_matrix(operator): @pytest.mark.parametrize("interface", ["torch-gpu", "torch-cpu"]) @pytest.mark.skipif(not TORCH_AVAILABLE, reason="Pytorch is not installed") -def test_adjoint_and_grad(operator, ndft_matrix, interface): +def test_adjoint_and_grad(operator, interface): """Test the adjoint and gradient of the operator.""" if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") - ndft_matrix_torch = to_interface(ndft_matrix, interface=interface) - ksp_data = to_interface(kspace_from_op(operator), interface=interface) - img_data = to_interface(image_from_op(operator), interface=interface) + + if torch.is_tensor(operator.samples): + operator.samples = operator.samples.cpu().detach().numpy() + + operator.samples = to_interface( + operator.samples, interface=interface + ) # [2048, 2] more samples!!!! + ksp_data = to_interface( + kspace_from_op(operator), interface=interface + ) # y [1, 2048] + img_data = to_interface(image_from_op(operator), interface=interface) # [1, 32, 32] ksp_data.requires_grad = True + operator.samples.requires_grad = True with torch.autograd.set_detect_anomaly(True): - adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) - adj_data_ndft = (ndft_matrix_torch.conj().T @ ksp_data.flatten()).reshape( + adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) # [ 1, 16, 16] + adj_data_ndft = (ndft_matrix(operator).conj().T @ ksp_data.flatten()).reshape( adj_data.shape ) loss_nufft = torch.mean(torch.abs(adj_data - img_data) ** 2) loss_ndft = torch.mean(torch.abs(adj_data_ndft - img_data) ** 2) + # Check if nufft and ndft w.r.t trajectory are close in the backprop + gradient_ndft_ktraj = torch.autograd.grad( + loss_ndft, operator.samples, retain_graph=True + )[0] + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[0] + + assert torch.allclose( + gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2 + ) # FIXME: plot the gradient + # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] gradient_nufft_kdata = torch.autograd.grad(loss_nufft, ksp_data, retain_graph=True)[ @@ -139,27 +108,36 @@ def test_forward_and_grad(operator, interface): if operator.backend == "finufft" and "gpu" in interface: pytest.skip("GPU not supported for finufft backend") - ktraj = to_interface(np.copy(operator.samples), interface=interface) + if torch.is_tensor(operator.samples): + operator.samples = operator.samples.cpu().detach().numpy() + operator.samples = to_interface(operator.samples, interface=interface) ksp_data_ref = to_interface(kspace_from_op(operator), interface=interface) img_data = to_interface(image_from_op(operator), interface=interface) img_data.requires_grad = True - ktraj.requires_grad = True + operator.samples.requires_grad = True with torch.autograd.set_detect_anomaly(True): - ksp_data = operator.op(img_data, ktraj).reshape(ksp_data_ref.shape) - ndft_matrix_torch = ndft_matrix_ktraj(operator, ktraj) - ksp_data_ndft = (ndft_matrix_torch @ img_data.flatten()).reshape(ksp_data.shape) + ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) + ksp_data_ndft = (ndft_matrix(operator) @ img_data.flatten()).reshape( + ksp_data.shape + ) loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) # Check if nufft and ndft w.r.t trajectory are close in the backprop - gradient_ndft_ktraj = torch.autograd.grad(loss_ndft, ktraj, retain_graph=True)[0] - gradient_nufft_ktraj = torch.autograd.grad(loss_nufft, ktraj, retain_graph=True)[ # + gradient_ndft_ktraj = torch.autograd.grad( + loss_ndft, operator.samples, retain_graph=True + )[ + 0 + ] # ktraj就是operator.samples实际上 + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[ # 0 ] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=6e-3) + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] From a19f242a981a153444c4d34dcc94b523f9eb2db4 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Wed, 29 May 2024 21:26:45 +0200 Subject: [PATCH 07/45] update for forward --- src/mrinufft/_utils.py | 2 +- src/mrinufft/operators/autodiff.py | 23 --------------- .../operators/interfaces/nudft_numpy.py | 1 + tests/test_autodiff.py | 29 +++++++++---------- 4 files changed, 15 insertions(+), 40 deletions(-) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 5ae52926..4e842568 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -75,7 +75,7 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point - module = get_array_module(trajectory) + module = get_array_module(trajectory) # check if the trajectory is a tensor try: new_traj = ( trajectory.clone() diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 23a0c9c8..8d2842a0 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -67,29 +67,6 @@ def backward(ctx, dx): """Backward kspace -> image.""" (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] grad_traj = None - # im_size = dx.size()[2:] - # r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] - # grid_r = torch.meshgrid(*r, indexing="ij") - # grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] #[1, 2, 16, 16] - - # diag_y = torch.diag_embed(y) #[1, 256, 256] 想要to be [1, 256, 16, 16] - - # ifft_diag_y = torch.cat( - # [ - # ctx.nufft_op.adj_op(diag_y[:, i : i + 1, :]) - # for i in range(diag_y.size(1)) - # ], - # dim=1, - # ) # [1, 2048, 32, 32] - - # grad_traj = torch.cat( - # [ - # (dx * grid_r[:, i : i + 1, :, :] * ifft_diag_y).sum(dim=(2, 3)) - # for i in range(grid_r.size(1)) - # ] - # ).type_as(traj) - - # grad_traj = torch.transpose(grad_traj, 0, 1).type_as(traj) return ctx.nufft_op.op(dx), grad_traj, None diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index b9263e8d..eba7788a 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -11,6 +11,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): """Get the NDFT Fourier Matrix.""" + module = get_array_module(ktraj) ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 8f3cec4a..70e892cf 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -63,18 +63,14 @@ def test_adjoint_and_grad(operator, interface): if torch.is_tensor(operator.samples): operator.samples = operator.samples.cpu().detach().numpy() - operator.samples = to_interface( - operator.samples, interface=interface - ) # [2048, 2] more samples!!!! - ksp_data = to_interface( - kspace_from_op(operator), interface=interface - ) # y [1, 2048] - img_data = to_interface(image_from_op(operator), interface=interface) # [1, 32, 32] + operator.samples = to_interface(operator.samples, interface=interface) + ksp_data = to_interface(kspace_from_op(operator), interface=interface) + img_data = to_interface(image_from_op(operator), interface=interface) ksp_data.requires_grad = True operator.samples.requires_grad = True with torch.autograd.set_detect_anomaly(True): - adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) # [ 1, 16, 16] + adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) adj_data_ndft = (ndft_matrix(operator).conj().T @ ksp_data.flatten()).reshape( adj_data.shape ) @@ -84,14 +80,17 @@ def test_adjoint_and_grad(operator, interface): # Check if nufft and ndft w.r.t trajectory are close in the backprop gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True - )[0] + )[ + 0 + ] gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True - )[0] + )[ + 0 + ] - assert torch.allclose( - gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2 - ) # FIXME: plot the gradient + + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) #FIXME: plot the gradient # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -120,9 +119,7 @@ def test_forward_and_grad(operator, interface): with torch.autograd.set_detect_anomaly(True): ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) - ksp_data_ndft = (ndft_matrix(operator) @ img_data.flatten()).reshape( - ksp_data.shape - ) + ksp_data_ndft = (ndft_matrix(operator) @ img_data.flatten()).reshape(ksp_data.shape) loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) From cd70dbfc507a44e087a84f8f4bc98009a6125eac Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Wed, 29 May 2024 21:57:58 +0200 Subject: [PATCH 08/45] update for forward --- src/mrinufft/_utils.py | 4 ---- src/mrinufft/operators/autodiff.py | 4 ---- src/mrinufft/operators/interfaces/nudft_numpy.py | 2 -- 3 files changed, 10 deletions(-) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 011cccb9..0a57c60a 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -75,11 +75,7 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point -<<<<<<< HEAD module = get_array_module(trajectory) # check if the trajectory is a tensor -======= - #FIXME: check if trajectory is torch. get_array_module : torch. Basically output must be torch if input is torch and everything must be done in torch. ->>>>>>> origin/autodiff_ktraj try: new_traj = ( trajectory.clone() diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index c64ee786..8d2842a0 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -87,11 +87,7 @@ def __init__(self, nufft_op): def op(self, x): r"""Compute the forward image -> k-space.""" -<<<<<<< HEAD return _NUFFT_OP.apply(x, self.samples, self.nufft_op) -======= - return _NUFFT_OP.apply(x, self.nufft_op) ->>>>>>> origin/autodiff_ktraj def adj_op(self, kspace): r"""Compute the adjoint k-space -> image.""" diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 8b9cd550..b9263e8d 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -10,9 +10,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): - #FIXME: check if trajectory is torch. get_array_module : torch. Basically output must be torch if input is torch and everything must be done in torch. """Get the NDFT Fourier Matrix.""" - module = get_array_module(ktraj) ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) From 1d7f30eadc84018e64a8f025debe5d2485aaf37d Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 30 May 2024 17:03:55 +0200 Subject: [PATCH 09/45] update for adjoint --- src/mrinufft/operators/autodiff.py | 87 ++++++++++++++----- src/mrinufft/operators/base.py | 28 +++--- .../operators/interfaces/cufinufft.py | 16 +++- tests/test_autodiff.py | 4 +- 4 files changed, 92 insertions(+), 43 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 8d2842a0..fca5045e 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -29,27 +29,31 @@ def backward(ctx, dy): """Backward image -> k-space.""" (x, traj) = ctx.saved_tensors - im_size = x.size()[1:] - r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] - grid_r = torch.meshgrid(*r, indexing="ij") - grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...] - - grid_x = x * grid_r # Element-wise multiplication: x * r - nufft_dx_dom = torch.cat( - [ - ctx.nufft_op.op(grid_x[:, i : i + 1, :, :]) - for i in range(grid_x.size(1)) - ], - dim=1, - ) # Compute A(x * r) for each channel and concatenate along this dimension - - grad_traj = torch.transpose( - (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 - ).type_as( - traj - ) # Compute gradient with respect to trajectory: -i * dy' * A(x * r) - - return ctx.nufft_op.adj_op(dy), grad_traj, None + if ctx.nufft_op._grad_wrt_data: + grad_data = ctx.nufft_op.adj_op(dy) + if ctx.nufft_op._grad_wrt_traj: + im_size = x.size()[1:] + r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + grid_r = torch.meshgrid(*r, indexing="ij") + grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...] + + grid_x = x * grid_r # Element-wise multiplication: x * r + nufft_dx_dom = torch.cat( + [ + ctx.nufft_op.op(grid_x[:, i : i + 1, :, :]) + for i in range(grid_x.size(1)) + ], + dim=1, + ) # Compute A(x * r) for each channel and concatenate along this dimension + + grad_traj = torch.transpose( + (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 + ).type_as( + traj + ) # Compute gradient with respect to trajectory: -i * dy' * A(x * r) + else: + grad_traj = None + return grad_data, grad_traj, None class _NUFFT_ADJOP(torch.autograd.Function): @@ -66,8 +70,40 @@ def forward(ctx, y, traj, nufft_op): def backward(ctx, dx): """Backward kspace -> image.""" (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] - grad_traj = None - return ctx.nufft_op.op(dx), grad_traj, None + + grad_data = None + grad_traj = None + if ctx.nufft_op._grad_wrt_data: + grad_data = ctx.nufft_op.op(dx) + + if ctx.nufft_op._grad_wrt_traj: + + ctx.nufft_op.raw_op.toggle_grad_traj() + + im_size = dx.size()[2:] + r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + grid_r = torch.meshgrid(*r, indexing="ij") + grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] #[1, 2, 16, 16] + + + grid_dx = torch.conj(dx) * grid_r + inufft_dx_dom = torch.cat( + [ + ctx.nufft_op.op(grid_dx[:, i : i + 1, :, :]) + for i in range(grid_dx.size(1)) + ], + dim=1, + ) + + grad_traj = torch.transpose( + (1j * y * inufft_dx_dom).squeeze(), 0, 1 + ).type_as( + traj + ) + + ctx.nufft_op.raw_op.toggle_grad_traj() + + return grad_data, grad_traj, None class MRINufftAutoGrad(torch.nn.Module): @@ -79,11 +115,14 @@ class MRINufftAutoGrad(torch.nn.Module): nufft_op: Classic Non differentiable MRI-NUFFT operator. """ - def __init__(self, nufft_op): + def __init__(self, nufft_op, wrt_data= True, wrt_traj=False): + super().__init__() if nufft_op.squeeze_dims: raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op + self.nufft_op._grad_wrt_traj = wrt_traj + self.nufft_op._grad_wrt_data = wrt_data def op(self, x): r"""Compute the forward image -> k-space.""" diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 0bec7fcf..d56f92ef 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -58,7 +58,7 @@ def list_backends(available_only=False): ] -def get_operator(backend_name: str, *args, autograd=None, **kwargs): +def get_operator(backend_name: str, *args, wrt_data, wrt_traj, **kwargs): """Return an MRI Fourier operator interface using the correct backend. Parameters @@ -102,11 +102,11 @@ class or instance of class if args or kwargs are given. if args or kwargs: operator = operator(*args, **kwargs) - if autograd: - if isinstance(operator, FourierOperatorBase): - operator = operator.make_autograd(variable=autograd) - else: # partial - operator = partial(operator.with_autograd, variable=autograd) + #if autograd: + if isinstance(operator, FourierOperatorBase): + operator = operator.make_autograd(wrt_data, wrt_traj) + else: # partial + operator = partial(operator.with_autograd, wrt_data, wrt_traj) return operator @@ -253,12 +253,12 @@ def with_off_resonnance_correction(self, B, C, indices): return MRIFourierCorrected(self, B, C, indices) - def make_autograd(self, variable="data"): + def make_autograd(self, wrt_data=True, wrt_traj=False): """Make a new Operator with autodiff support. Parameters ---------- - variable: str, default data + variable: , default data variable on which the gradient is computed with respect to. Returns @@ -273,11 +273,9 @@ def make_autograd(self, variable="data"): """ if not AUTOGRAD_AVAILABLE: raise ValueError("Autograd not available, ensure torch is installed.") - if variable == "data": - return MRINufftAutoGrad(self) - else: - raise ValueError(f"Autodiff with respect to {variable} is not supported.") - + + return MRINufftAutoGrad(self,wrt_data=wrt_data, wrt_traj=wrt_traj) + def compute_density(self, method=None): """Compute the density compensation weights and set it. @@ -448,9 +446,9 @@ def __repr__(self): ) @classmethod - def with_autograd(cls, variable, *args, **kwargs): + def with_autograd(cls, wrt_data, wrt_traj, *args, **kwargs): """Return a Fourier operator with autograd capabilities.""" - return cls(*args, **kwargs).make_autograd(variable) + return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj) class FourierOperatorCPU(FourierOperatorBase): diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index ece1c253..89209fa9 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -66,6 +66,7 @@ def __init__( # the first element is dummy to index type 1 with 1 # and type 2 with 2. self.plans = [None, None, None] + self.grad_plan = None for i in [1, 2]: self._make_plan(i, **kwargs) @@ -88,7 +89,16 @@ def _make_plan(self, typ, **kwargs): dtype=DTYPE_R2C[str(self.samples.dtype)], **kwargs, ) - + def _make_plan_grad(self, **kwargs): + self.grad_plan = Plan( + 1, + self.shape, + self.n_trans, + self.eps, + dtype=DTYPE_R2C[str(self.samples.dtype)], + isign = -1, + **kwargs, + ) def _set_pts(self, typ): self.plans[typ].setpts( @@ -110,7 +120,9 @@ def type1(self, coeff_data, grid_data): def type2(self, grid_data, coeff_data): """Type 2 transform. Uniform to non-uniform.""" return self.plans[2].execute(grid_data, coeff_data) - + + def toggle_grad_traj(self): + self.plans[1], self.grad_plan = self.grad_plan, self.plans[1] class MRICufiNUFFT(FourierOperatorBase): """MRI Transform operator, build around cufinufft. diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 70e892cf..01f0fcc3 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -37,7 +37,7 @@ def operator(kspace_loc, shape, backend, autograd): """Create NUFFT operator with autodiff capabilities.""" kspace_loc = kspace_loc.astype(np.float32) - nufft = get_operator(backend_name=backend, autograd=autograd)( + nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=True)( samples=kspace_loc, shape=shape, smaps=None, @@ -90,7 +90,7 @@ def test_adjoint_and_grad(operator, interface): ] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) #FIXME: plot the gradient + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] From d53a7ca643b8c5b61360a90f7dea3f861c3cf077 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 20:50:08 +0200 Subject: [PATCH 10/45] Fixed working codes --- src/mrinufft/operators/autodiff.py | 11 +++- .../operators/interfaces/cufinufft.py | 26 ++++++---- tests/test_autodiff.py | 51 +++++++++---------- 3 files changed, 51 insertions(+), 37 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index fca5045e..ec566ba5 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -72,14 +72,19 @@ def backward(ctx, dx): (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] grad_data = None - grad_traj = None + grad_traj = None + print("In AutoGrad") if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.op(dx) if ctx.nufft_op._grad_wrt_traj: - + print(ctx.nufft_op.raw_op.plans) + print(ctx.nufft_op.raw_op.grad_plan) ctx.nufft_op.raw_op.toggle_grad_traj() + print(ctx.nufft_op.raw_op.plans) + print(ctx.nufft_op.raw_op.grad_plan) + im_size = dx.size()[2:] r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] grid_r = torch.meshgrid(*r, indexing="ij") @@ -122,6 +127,8 @@ def __init__(self, nufft_op, wrt_data= True, wrt_traj=False): raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj + if wrt_traj: + self.nufft_op.raw_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data def op(self, x): diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 89209fa9..9e5e7bcd 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -91,21 +91,29 @@ def _make_plan(self, typ, **kwargs): ) def _make_plan_grad(self, **kwargs): self.grad_plan = Plan( - 1, + 2, self.shape, self.n_trans, self.eps, dtype=DTYPE_R2C[str(self.samples.dtype)], - isign = -1, + isign = 1, **kwargs, ) + self._set_pts(typ='grad') def _set_pts(self, typ): - self.plans[typ].setpts( - cp.array(self.samples[:, 0], copy=False), - cp.array(self.samples[:, 1], copy=False), - cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, - ) + if typ == 'grad': + self.grad_plan.setpts( + cp.array(self.samples[:, 0], copy=False), + cp.array(self.samples[:, 1], copy=False), + cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, + ) + else: + self.plans[typ].setpts( + cp.array(self.samples[:, 0], copy=False), + cp.array(self.samples[:, 1], copy=False), + cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, + ) def _destroy_plan(self, typ): if self.plans[typ] is not None: @@ -122,7 +130,8 @@ def type2(self, grid_data, coeff_data): return self.plans[2].execute(grid_data, coeff_data) def toggle_grad_traj(self): - self.plans[1], self.grad_plan = self.grad_plan, self.plans[1] + """Toggle between the gradient trajectory and the plan for type 1 transform.""" + self.plans[2], self.grad_plan = self.grad_plan, self.plans[2] class MRICufiNUFFT(FourierOperatorBase): """MRI Transform operator, build around cufinufft. @@ -191,7 +200,6 @@ def __init__( verbose=False, squeeze_dims=False, n_trans=1, - grad_traj=False, **kwargs, ): # run the availaility check here to get detailled output. diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index 01f0fcc3..fdf85fd3 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -6,9 +6,6 @@ from pytest_cases import parametrize_with_cases, parametrize, fixture from case_trajectories import CasesTrajectories from mrinufft.operators import get_operator -import warnings -from mrinufft._utils import proper_trajectory -import matplotlib.pyplot as plt from helpers import ( kspace_from_op, @@ -26,7 +23,6 @@ @fixture(scope="module") @parametrize(backend=["cufinufft", "finufft", "gpunufft"]) -@parametrize(autograd=["data"]) @parametrize_with_cases( "kspace_loc, shape", cases=[ @@ -34,10 +30,15 @@ CasesTrajectories.case_nyquist_radial2D, ], # 2D cases only for reduced memory footprint. ) -def operator(kspace_loc, shape, backend, autograd): +def operator(kspace_loc, shape, backend): """Create NUFFT operator with autodiff capabilities.""" kspace_loc = kspace_loc.astype(np.float32) - nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=True)( + wrt_traj = True + if backend == "gpunufft": + # Gradient wrt to trajectory is not yet supported for gpunufft + wrt_traj = False + + nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=wrt_traj)( samples=kspace_loc, shape=shape, smaps=None, @@ -76,21 +77,20 @@ def test_adjoint_and_grad(operator, interface): ) loss_nufft = torch.mean(torch.abs(adj_data - img_data) ** 2) loss_ndft = torch.mean(torch.abs(adj_data_ndft - img_data) ** 2) - - # Check if nufft and ndft w.r.t trajectory are close in the backprop - gradient_ndft_ktraj = torch.autograd.grad( - loss_ndft, operator.samples, retain_graph=True - )[ - 0 - ] - gradient_nufft_ktraj = torch.autograd.grad( - loss_nufft, operator.samples, retain_graph=True - )[ - 0 - ] - - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) + if operator.backend != "gpunufft": + # Check if nufft and ndft w.r.t trajectory are close in the backprop + gradient_ndft_ktraj = torch.autograd.grad( + loss_ndft, operator.samples, retain_graph=True + )[ + 0 + ] + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[ + 0 + ] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -128,13 +128,12 @@ def test_forward_and_grad(operator, interface): loss_ndft, operator.samples, retain_graph=True )[ 0 - ] # ktraj就是operator.samples实际上 - gradient_nufft_ktraj = torch.autograd.grad( - loss_nufft, operator.samples, retain_graph=True - )[ # - 0 ] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + if operator.backend != "gpunufft": + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[0] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] From df5006f4a12bb8b242b685a658620c0f8cbbd54d Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 20:51:28 +0200 Subject: [PATCH 11/45] Remove vscode stuff --- .vscode/launch.json | 35 ----------------------------------- .vscode/settings.json | 6 ------ 2 files changed, 41 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index cb49e2bd..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,35 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - - - { - "name": "Python Debugger: Current File", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "module": "pytest", - "args": [ - "/volatile/Caini/mri-nufft/tests/test_autodiff.py" - ], - "console": "integratedTerminal", - "python": "/volatile/Caini/Envs/projector/bin/python", - "env": { - "PYTHONPATH": "/mri-nufft" - }, - - }, - { - "name": "DTest", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "purpose": ["debug-test"], - "console": "integratedTerminal", - "justMyCode": false - }, - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index d23fb8e7..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,6 +0,0 @@ -{ - "python.testing.pytestArgs": [ - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file From 5c29929a9cb45a4034bec9a193e783095c6facc4 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 20:52:12 +0200 Subject: [PATCH 12/45] Ruff and black --- src/mrinufft/_utils.py | 4 +-- src/mrinufft/operators/autodiff.py | 24 +++++++-------- src/mrinufft/operators/base.py | 8 ++--- .../operators/interfaces/cufinufft.py | 10 ++++--- tests/test_autodiff.py | 30 ++++++++----------- 5 files changed, 35 insertions(+), 41 deletions(-) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 0a57c60a..2b85affa 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -57,7 +57,7 @@ def auto_cast(array, dtype: DTypeLike): return array.astype(dtype) -def proper_trajectory(trajectory, normalize="pi"): +def proper_trajectory(trajectory, normalize="pi"): """Normalize the trajectory to be used by NUFFT operators. Parameters @@ -75,7 +75,7 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point - module = get_array_module(trajectory) # check if the trajectory is a tensor + module = get_array_module(trajectory) # check if the trajectory is a tensor try: new_traj = ( trajectory.clone() diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index ec566ba5..f0869855 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -70,13 +70,13 @@ def forward(ctx, y, traj, nufft_op): def backward(ctx, dx): """Backward kspace -> image.""" (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] - - grad_data = None + + grad_data = None grad_traj = None - print("In AutoGrad") - if ctx.nufft_op._grad_wrt_data: + print("In AutoGrad") + if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.op(dx) - + if ctx.nufft_op._grad_wrt_traj: print(ctx.nufft_op.raw_op.plans) print(ctx.nufft_op.raw_op.grad_plan) @@ -84,13 +84,12 @@ def backward(ctx, dx): print(ctx.nufft_op.raw_op.plans) print(ctx.nufft_op.raw_op.grad_plan) - + im_size = dx.size()[2:] r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] grid_r = torch.meshgrid(*r, indexing="ij") - grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] #[1, 2, 16, 16] + grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] # [1, 2, 16, 16] - grid_dx = torch.conj(dx) * grid_r inufft_dx_dom = torch.cat( [ @@ -102,12 +101,10 @@ def backward(ctx, dx): grad_traj = torch.transpose( (1j * y * inufft_dx_dom).squeeze(), 0, 1 - ).type_as( - traj - ) + ).type_as(traj) ctx.nufft_op.raw_op.toggle_grad_traj() - + return grad_data, grad_traj, None @@ -120,8 +117,7 @@ class MRINufftAutoGrad(torch.nn.Module): nufft_op: Classic Non differentiable MRI-NUFFT operator. """ - def __init__(self, nufft_op, wrt_data= True, wrt_traj=False): - + def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): super().__init__() if nufft_op.squeeze_dims: raise ValueError("Squeezing dimensions is not " "supported for autodiff.") diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index d56f92ef..2427bdeb 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -102,7 +102,7 @@ class or instance of class if args or kwargs are given. if args or kwargs: operator = operator(*args, **kwargs) - #if autograd: + # if autograd: if isinstance(operator, FourierOperatorBase): operator = operator.make_autograd(wrt_data, wrt_traj) else: # partial @@ -273,9 +273,9 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): """ if not AUTOGRAD_AVAILABLE: raise ValueError("Autograd not available, ensure torch is installed.") - - return MRINufftAutoGrad(self,wrt_data=wrt_data, wrt_traj=wrt_traj) - + + return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj) + def compute_density(self, method=None): """Compute the density compensation weights and set it. diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 9e5e7bcd..173effd2 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -89,6 +89,7 @@ def _make_plan(self, typ, **kwargs): dtype=DTYPE_R2C[str(self.samples.dtype)], **kwargs, ) + def _make_plan_grad(self, **kwargs): self.grad_plan = Plan( 2, @@ -96,13 +97,13 @@ def _make_plan_grad(self, **kwargs): self.n_trans, self.eps, dtype=DTYPE_R2C[str(self.samples.dtype)], - isign = 1, + isign=1, **kwargs, ) - self._set_pts(typ='grad') + self._set_pts(typ="grad") def _set_pts(self, typ): - if typ == 'grad': + if typ == "grad": self.grad_plan.setpts( cp.array(self.samples[:, 0], copy=False), cp.array(self.samples[:, 1], copy=False), @@ -128,11 +129,12 @@ def type1(self, coeff_data, grid_data): def type2(self, grid_data, coeff_data): """Type 2 transform. Uniform to non-uniform.""" return self.plans[2].execute(grid_data, coeff_data) - + def toggle_grad_traj(self): """Toggle between the gradient trajectory and the plan for type 1 transform.""" self.plans[2], self.grad_plan = self.grad_plan, self.plans[2] + class MRICufiNUFFT(FourierOperatorBase): """MRI Transform operator, build around cufinufft. diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index fdf85fd3..ed0721a1 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -37,7 +37,7 @@ def operator(kspace_loc, shape, backend): if backend == "gpunufft": # Gradient wrt to trajectory is not yet supported for gpunufft wrt_traj = False - + nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=wrt_traj)( samples=kspace_loc, shape=shape, @@ -64,33 +64,29 @@ def test_adjoint_and_grad(operator, interface): if torch.is_tensor(operator.samples): operator.samples = operator.samples.cpu().detach().numpy() - operator.samples = to_interface(operator.samples, interface=interface) - ksp_data = to_interface(kspace_from_op(operator), interface=interface) - img_data = to_interface(image_from_op(operator), interface=interface) + operator.samples = to_interface(operator.samples, interface=interface) + ksp_data = to_interface(kspace_from_op(operator), interface=interface) + img_data = to_interface(image_from_op(operator), interface=interface) ksp_data.requires_grad = True operator.samples.requires_grad = True with torch.autograd.set_detect_anomaly(True): - adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) + adj_data = operator.adj_op(ksp_data).reshape(img_data.shape) adj_data_ndft = (ndft_matrix(operator).conj().T @ ksp_data.flatten()).reshape( adj_data.shape ) loss_nufft = torch.mean(torch.abs(adj_data - img_data) ** 2) loss_ndft = torch.mean(torch.abs(adj_data_ndft - img_data) ** 2) - + if operator.backend != "gpunufft": # Check if nufft and ndft w.r.t trajectory are close in the backprop gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True - )[ - 0 - ] + )[0] gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True - )[ - 0 - ] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + )[0] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -119,16 +115,16 @@ def test_forward_and_grad(operator, interface): with torch.autograd.set_detect_anomaly(True): ksp_data = operator.op(img_data).reshape(ksp_data_ref.shape) - ksp_data_ndft = (ndft_matrix(operator) @ img_data.flatten()).reshape(ksp_data.shape) + ksp_data_ndft = (ndft_matrix(operator) @ img_data.flatten()).reshape( + ksp_data.shape + ) loss_nufft = torch.mean(torch.abs(ksp_data - ksp_data_ref) ** 2) loss_ndft = torch.mean(torch.abs(ksp_data_ndft - ksp_data_ref) ** 2) # Check if nufft and ndft w.r.t trajectory are close in the backprop gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True - )[ - 0 - ] + )[0] if operator.backend != "gpunufft": gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True From 156f3223c645d5e487f3b4400146d037062545b6 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 20:55:16 +0200 Subject: [PATCH 13/45] Merging --- src/mrinufft/extras/smaps.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 9051b8d3..43a594b1 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -87,7 +87,9 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + window = ( + xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + ) else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data From 6b5ef8eb1205180938bef29640076eb75b72cdeb Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 20:58:59 +0200 Subject: [PATCH 14/45] Remove torch dependence --- src/mrinufft/operators/interfaces/nudft_numpy.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index b9263e8d..f7ddc1a0 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -4,7 +4,6 @@ import numpy as np import scipy as sp -import torch from ..base import FourierOperatorCPU from mrinufft._utils import proper_trajectory, get_array_module @@ -17,6 +16,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): ndim = len(shape) if module.__name__ == "torch": + torch = module device = ktraj.device dtype = torch.complex64 r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] @@ -24,7 +24,6 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) traj_grid = torch.matmul(ktraj, grid_r) matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() - else: r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) From f6a6e7a678fae34172fed6cb246e2f462cef85ce Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 21:01:43 +0200 Subject: [PATCH 15/45] Remove bad usage of torch --- src/mrinufft/operators/interfaces/cufinufft.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index c3b9bfb0..3cffc551 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -1,5 +1,4 @@ """Provides Operator for MR Image processing on GPU.""" -import torch import warnings import numpy as np from mrinufft.operators.base import FourierOperatorBase, with_numpy_cupy From 0dca414761abcb47aa08222b280987a5d50d1b94 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 21:03:05 +0200 Subject: [PATCH 16/45] Fix get_op --- src/mrinufft/operators/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 6e2d4509..52bd0f51 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -59,7 +59,7 @@ def list_backends(available_only=False): ] -def get_operator(backend_name: str, *args, wrt_data, wrt_traj, **kwargs): +def get_operator(backend_name: str, wrt_data=False, wrt_traj=False, *args, **kwargs): """Return an MRI Fourier operator interface using the correct backend. Parameters From fa5a60310d6e7593c5fd2e6045aff146e6ea1a06 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Thu, 30 May 2024 21:13:10 +0200 Subject: [PATCH 17/45] Added squeeze dims check right --- src/mrinufft/operators/autodiff.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index f0869855..42c6efcb 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -119,7 +119,7 @@ class MRINufftAutoGrad(torch.nn.Module): def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): super().__init__() - if nufft_op.squeeze_dims: + if wrt_data or wrt_traj and nufft_op.squeeze_dims: raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj From e1264d86425d510d8346ff1fb080c8924e93f094 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Fri, 31 May 2024 11:07:41 +0200 Subject: [PATCH 18/45] Fixes --- src/mrinufft/operators/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 52bd0f51..fc3e6af2 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -108,6 +108,7 @@ class or instance of class if args or kwargs are given. operator = operator.make_autograd(wrt_data, wrt_traj) else: # partial operator = partial(operator.with_autograd, wrt_data, wrt_traj) + operator.__name__ = operator.backend return operator From 3aa6574a9f03031f6b54ea22730d307cd9d6402d Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Fri, 31 May 2024 11:15:05 +0200 Subject: [PATCH 19/45] update for finufft --- .vscode/settings.json | 7 +++++ src/mrinufft/operators/interfaces/finufft.py | 32 +++++++++++++++++--- 2 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9b388533 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index f84ef41d..dc6b0609 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -32,6 +32,7 @@ def __init__( # the first element is dummy to index type 1 with 1 # and type 2 with 2. self.plans = [None, None, None] + self.grad_plan = None for i in [1, 2]: self._make_plan(i, **kwargs) @@ -47,11 +48,29 @@ def _make_plan(self, typ, **kwargs): **kwargs, ) + def _make_plan_grad(self, **kwargs): + self.grad_plan = Plan( + 2, + self.shape, + self.n_trans, + self.eps, + dtype="complex64" if self.samples.dtype == "float32" else "complex128", + isign = 1, + **kwargs, + ) + self._set_pts(typ="grad") + def _set_pts(self, typ): - fpts_axes = [None, None, None] - for i in range(self.ndim): - fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) - self.plans[typ].setpts(*fpts_axes) + if typ == "grad": + fpts_axes = [None, None, None] + for i in range(self.ndim): + fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) + self.grad_plan.setpts(*fpts_axes) + else: + fpts_axes = [None, None, None] + for i in range(self.ndim): + fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) + self.plans[typ].setpts(*fpts_axes) def adj_op(self, coeffs_data, grid_data): """Type 1 transform. Non Uniform to Uniform.""" @@ -66,7 +85,10 @@ def op(self, coeffs_data, grid_data): grid_data = grid_data.reshape(self.shape) coeffs_data = coeffs_data.reshape(len(self.samples)) return self.plans[2].execute(grid_data, coeffs_data) - + + def toggle_grad_traj(self): + """Toggle between the gradient trajectory and the plan for type 1 transform.""" + self.plans[2], self.grad_plan = self.grad_plan, self.plans[2] class MRIfinufft(FourierOperatorCPU): """MRI Transform Operator using finufft. From c0868971d621bdb3038dca3b9d953e14b407cc7d Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Fri, 31 May 2024 13:51:22 +0200 Subject: [PATCH 20/45] Add support to gpuNUFFT --- src/mrinufft/_utils.py | 2 +- src/mrinufft/operators/autodiff.py | 26 ++++++++----- src/mrinufft/operators/base.py | 13 +++---- src/mrinufft/operators/interfaces/gpunufft.py | 31 +++++++++------- tests/case_trajectories.py | 8 +++- tests/test_autodiff.py | 37 ++++++++----------- 6 files changed, 63 insertions(+), 54 deletions(-) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 252aee7a..8c0e70c6 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -35,7 +35,7 @@ } try: from tensorflow.experimental import numpy as tnp - + ARRAY_LIBS["tensorflow"] = (tnp, tnp.ndarray) except ImportError: pass diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 42c6efcb..3676e861 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -33,7 +33,13 @@ def backward(ctx, dy): grad_data = ctx.nufft_op.adj_op(dy) if ctx.nufft_op._grad_wrt_traj: im_size = x.size()[1:] - r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + factor = 1 + if ctx.nufft_op.backend == "gpunufft": + factor *= np.pi * 2 + r = [ + torch.linspace(-size / 2, size / 2 - 1, size)*factor + for size in im_size + ] grid_r = torch.meshgrid(*r, indexing="ij") grid_r = torch.stack(grid_r, dim=0).type_as(x)[None, ...] @@ -78,15 +84,15 @@ def backward(ctx, dx): grad_data = ctx.nufft_op.op(dx) if ctx.nufft_op._grad_wrt_traj: - print(ctx.nufft_op.raw_op.plans) - print(ctx.nufft_op.raw_op.grad_plan) ctx.nufft_op.raw_op.toggle_grad_traj() - - print(ctx.nufft_op.raw_op.plans) - print(ctx.nufft_op.raw_op.grad_plan) - im_size = dx.size()[2:] - r = [torch.linspace(-size / 2, size / 2 - 1, size) for size in im_size] + factor = 1 + if ctx.nufft_op.backend == "gpunufft": + factor *= np.pi * 2 + r = [ + torch.linspace(-size / 2, size / 2 - 1, size)*factor + for size in im_size + ] grid_r = torch.meshgrid(*r, indexing="ij") grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] # [1, 2, 16, 16] @@ -119,11 +125,11 @@ class MRINufftAutoGrad(torch.nn.Module): def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): super().__init__() - if wrt_data or wrt_traj and nufft_op.squeeze_dims: + if (wrt_data or wrt_traj) and nufft_op.squeeze_dims: raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj - if wrt_traj: + if wrt_traj and self.nufft_op.backend != 'gpunufft': self.nufft_op.raw_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index fc3e6af2..bda18051 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -59,18 +59,18 @@ def list_backends(available_only=False): ] -def get_operator(backend_name: str, wrt_data=False, wrt_traj=False, *args, **kwargs): +def get_operator(backend_name: str, wrt_data: bool=False, wrt_traj: bool=False, + *args, **kwargs): """Return an MRI Fourier operator interface using the correct backend. Parameters ---------- backend_name: str Backend name - - autograd: str, default None - if set to "data" will provide an operator with autodiff capabilities with - respect to it. - + wrt_data: bool, default False + if set gradients wrt to data and images will be available. + wrt_traj: bool, default False + if set gradients wrt to trajectory will be available. *args, **kwargs: Arguments to pass to the operator constructor. @@ -108,7 +108,6 @@ class or instance of class if args or kwargs are given. operator = operator.make_autograd(wrt_data, wrt_traj) else: # partial operator = partial(operator.with_autograd, wrt_data, wrt_traj) - operator.__name__ = operator.backend return operator diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 58f075d9..8daa031c 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -183,6 +183,11 @@ def __init__( balance_workload, ) + def toggle_grad_traj(self): + """Toggle the gradient mode of the operator.""" + self.operator.toggle_grad_mode() + + def _reshape_image(self, image, direction="op"): """Reshape the image to the correct format.""" xp = get_array_module(image) @@ -383,7 +388,7 @@ def __init__( self.squeeze_dims = squeeze_dims self.compute_density(density) self.compute_smaps(smaps) - self.impl = RawGpuNUFFT( + self.raw_op = RawGpuNUFFT( samples=self.samples, shape=self.shape, n_coils=self.n_coils, @@ -411,10 +416,10 @@ def op(self, data, coeffs=None): """ B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples - op_func = self.impl.op + op_func = self.raw_op.op if is_cuda_array(data): - op_func = self.impl.op_direct - if not self.impl.use_gpu_direct: + op_func = self.raw_op.op_direct + if not self.raw_op.use_gpu_direct: warnings.warn( "Using direct GPU array without passing " "`use_gpu_direct=True`, this is memory inefficient." @@ -450,10 +455,10 @@ def adj_op(self, coeffs, data=None): """ B, C, XYZ, K = self.n_batchs, self.n_coils, self.shape, self.n_samples - adj_op_func = self.impl.adj_op + adj_op_func = self.raw_op.adj_op if is_cuda_array(coeffs): - adj_op_func = self.impl.adj_op_direct - if not self.impl.use_gpu_direct: + adj_op_func = self.raw_op.adj_op_direct + if not self.raw_op.use_gpu_direct: warnings.warn( "Using direct GPU array without passing " "`use_gpu_direct=True`, this is memory inefficient." @@ -474,7 +479,7 @@ def adj_op(self, coeffs, data=None): @property def uses_sense(self): """Return True if the Fourier Operator uses the SENSE method.""" - return self.impl.uses_sense + return self.raw_op.uses_sense @classmethod def pipe( @@ -513,7 +518,7 @@ def pipe( osf=1, **kwargs, ) - density_comp = grid_op.impl.operator.estimate_density_comp( + density_comp = grid_op.raw_op.operator.estimate_density_comp( max_iter=num_iterations ) if normalize: @@ -549,7 +554,7 @@ def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs): squeeze_dims=True, **kwargs, ) - return tmp_op.impl.operator.get_spectral_radius( + return tmp_op.raw_op.operator.get_spectral_radius( max_iter=max_iter, tolerance=tolerance ) @@ -604,7 +609,7 @@ def data_consistency(self, image_data, obs_data): "but is memory inefficient!" ) grad_func = super().data_consistency - if not self.impl.use_gpu_direct: + if not self.raw_op.use_gpu_direct: warnings.warn( "Using direct GPU array without passing " "`use_gpu_direct=True`, this is memory inefficient." @@ -623,9 +628,9 @@ def _dc_host(self, image_data, obs_data): for i in range(B): tmp_img.set(image_data_[i]) obs_data_tmp.set(obs_data_[i]) - ksp_tmp = self.impl.op_direct(tmp_img) + ksp_tmp = self.raw_op.op_direct(tmp_img) ksp_tmp -= obs_data_tmp - final_img[i] = self.impl.adj_op_direct(ksp_tmp).get() + final_img[i] = self.raw_op.adj_op_direct(ksp_tmp).get() return final_img # TODO : For data consistency the workflow is currently: diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index e4426579..0f76e425 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -50,7 +50,13 @@ def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32): trajectory = initialize_2D_radial(Nc, Ns) trajectory = rotate(trajectory, nb_rotations=Nr) return trajectory, (N, N, N) - + + def case_nyquist_radial3D_lowmem(self, Nc=2, Ns=16, Nr=2, N=10): + """Create a 3D radial trajectory.""" + trajectory = initialize_2D_radial(Nc, Ns) + trajectory = rotate(trajectory, nb_rotations=Nr) + return trajectory, (N, N, N) + def case_grid2D(self, N=16): """Create a 2D cartesian grid of frequencies locations.""" freq_1d = sp.fft.fftfreq(N) diff --git a/tests/test_autodiff.py b/tests/test_autodiff.py index ed0721a1..ede159ce 100644 --- a/tests/test_autodiff.py +++ b/tests/test_autodiff.py @@ -28,23 +28,18 @@ cases=[ CasesTrajectories.case_grid2D, CasesTrajectories.case_nyquist_radial2D, - ], # 2D cases only for reduced memory footprint. + CasesTrajectories.case_nyquist_radial3D_lowmem, + ], ) def operator(kspace_loc, shape, backend): """Create NUFFT operator with autodiff capabilities.""" kspace_loc = kspace_loc.astype(np.float32) - wrt_traj = True - if backend == "gpunufft": - # Gradient wrt to trajectory is not yet supported for gpunufft - wrt_traj = False - - nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=wrt_traj)( + nufft = get_operator(backend_name=backend, wrt_data=True, wrt_traj=True)( samples=kspace_loc, shape=shape, smaps=None, squeeze_dims=False, # Squeezing breaks dimensions ! ) - return nufft @@ -78,15 +73,14 @@ def test_adjoint_and_grad(operator, interface): loss_nufft = torch.mean(torch.abs(adj_data - img_data) ** 2) loss_ndft = torch.mean(torch.abs(adj_data_ndft - img_data) ** 2) - if operator.backend != "gpunufft": - # Check if nufft and ndft w.r.t trajectory are close in the backprop - gradient_ndft_ktraj = torch.autograd.grad( - loss_ndft, operator.samples, retain_graph=True - )[0] - gradient_nufft_ktraj = torch.autograd.grad( - loss_nufft, operator.samples, retain_graph=True - )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + # Check if nufft and ndft w.r.t trajectory are close in the backprop + gradient_ndft_ktraj = torch.autograd.grad( + loss_ndft, operator.samples, retain_graph=True + )[0] + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[0] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -125,11 +119,10 @@ def test_forward_and_grad(operator, interface): gradient_ndft_ktraj = torch.autograd.grad( loss_ndft, operator.samples, retain_graph=True )[0] - if operator.backend != "gpunufft": - gradient_nufft_ktraj = torch.autograd.grad( - loss_nufft, operator.samples, retain_graph=True - )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + gradient_nufft_ktraj = torch.autograd.grad( + loss_nufft, operator.samples, retain_graph=True + )[0] + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] From f21fb29ccb8d71a2448e677e8b785c8cdbfc39a4 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Fri, 31 May 2024 13:54:06 +0200 Subject: [PATCH 21/45] Add su-port for gpunufft --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8f7e00f3..ccb33b5f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ dynamic = ["version"] [project.optional-dependencies] -gpunufft = ["gpuNUFFT>=0.7.5", "cupy-cuda11x"] +gpunufft = ["gpuNUFFT>=0.8.0", "cupy-cuda11x"] cufinufft = ["cufinufft", "cupy-cuda11x"] finufft = ["finufft"] pynfft = ["pynfft2", "cython<3.0.0"] From 713ac665a7965dee984c33ca717a1adfb2df1274 Mon Sep 17 00:00:00 2001 From: Caini Pan <88090141+alineyyy@users.noreply.github.com> Date: Fri, 31 May 2024 14:10:44 +0200 Subject: [PATCH 22/45] Delete .vscode directory --- .vscode/settings.json | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9b388533..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file From abd9ffaa028f347546bed5fb7e3fc048dfb69006 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Sun, 2 Jun 2024 17:04:47 +0200 Subject: [PATCH 23/45] fix test_bindings --- tests/operators/test_bindings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/operators/test_bindings.py b/tests/operators/test_bindings.py index 8fb43257..18479988 100644 --- a/tests/operators/test_bindings.py +++ b/tests/operators/test_bindings.py @@ -18,7 +18,7 @@ ) def test_get_operator(backend, name): """Test the get_operator function.""" - assert mrinufft.get_operator(backend).__name__ == name + assert mrinufft.get_operator(backend).func.__self__.__name__ == name def test_get_operator_fail(): From 44d8579796c7a1a77b57a8c3a9688b285bf8774e Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Sun, 2 Jun 2024 17:53:21 +0200 Subject: [PATCH 24/45] fix gpunufft pipe --- src/mrinufft/density/nufft_based.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/density/nufft_based.py b/src/mrinufft/density/nufft_based.py index c2d4d7fa..4ae071ae 100644 --- a/src/mrinufft/density/nufft_based.py +++ b/src/mrinufft/density/nufft_based.py @@ -22,7 +22,7 @@ def pipe(traj, shape, backend="gpunufft", **kwargs): # here to avoid circular import from mrinufft.operators.base import get_operator - nufft_class = get_operator(backend) + nufft_class = get_operator(backend).func.__self__ if hasattr(nufft_class, "pipe"): return nufft_class.pipe(traj, shape, **kwargs) raise ValueError("backend does not have pipe iterations method.") From af33d5ce3ae19be18b8fc4e8e83abfe368322eef Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Sun, 2 Jun 2024 18:28:06 +0200 Subject: [PATCH 25/45] fix test-cpu --- src/mrinufft/operators/base.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index bda18051..57642285 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -24,7 +24,6 @@ AUTOGRAD_AVAILABLE = True try: - import torch from mrinufft.operators.autodiff import MRINufftAutoGrad except ImportError: AUTOGRAD_AVAILABLE = False From b82eb9018662adfea81d27172cf9bb4f935fe77e Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Sun, 2 Jun 2024 18:50:48 +0200 Subject: [PATCH 26/45] change test-ci --- .github/workflows/test-ci.yml | 5 +++++ src/mrinufft/operators/base.py | 1 + 2 files changed, 6 insertions(+) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index 72b708a4..cb34f6ee 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -42,6 +42,11 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install -e .[test] + + - name: Install Torch + shell: bash + run: | + python -m pip install torch --index-url https://download.pytorch.org/whl/cu118 - name: Install pynfft if: ${{ matrix.backend == 'pynfft' || env.ref_backend == 'pynfft' }} diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 57642285..bda18051 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -24,6 +24,7 @@ AUTOGRAD_AVAILABLE = True try: + import torch from mrinufft.operators.autodiff import MRINufftAutoGrad except ImportError: AUTOGRAD_AVAILABLE = False From 10f4b0b51646bb65d900fd78edf096a3a1747de4 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Sun, 2 Jun 2024 19:10:13 +0200 Subject: [PATCH 27/45] black style check --- .vscode/settings.json | 7 +++++++ src/mrinufft/_utils.py | 2 +- src/mrinufft/extras/smaps.py | 4 +--- src/mrinufft/operators/autodiff.py | 6 +++--- src/mrinufft/operators/base.py | 5 +++-- src/mrinufft/operators/interfaces/cufinufft.py | 1 + src/mrinufft/operators/interfaces/finufft.py | 5 +++-- src/mrinufft/operators/interfaces/gpunufft.py | 3 +-- tests/case_trajectories.py | 4 ++-- 9 files changed, 22 insertions(+), 15 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9b388533 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 8c0e70c6..252aee7a 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -35,7 +35,7 @@ } try: from tensorflow.experimental import numpy as tnp - + ARRAY_LIBS["tensorflow"] = (tnp, tnp.ndarray) except ImportError: pass diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 0bd7c789..2e2ed4af 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -87,9 +87,7 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = ( - xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 - ) + window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 3676e861..012fcb18 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -37,7 +37,7 @@ def backward(ctx, dy): if ctx.nufft_op.backend == "gpunufft": factor *= np.pi * 2 r = [ - torch.linspace(-size / 2, size / 2 - 1, size)*factor + torch.linspace(-size / 2, size / 2 - 1, size) * factor for size in im_size ] grid_r = torch.meshgrid(*r, indexing="ij") @@ -90,7 +90,7 @@ def backward(ctx, dx): if ctx.nufft_op.backend == "gpunufft": factor *= np.pi * 2 r = [ - torch.linspace(-size / 2, size / 2 - 1, size)*factor + torch.linspace(-size / 2, size / 2 - 1, size) * factor for size in im_size ] grid_r = torch.meshgrid(*r, indexing="ij") @@ -129,7 +129,7 @@ def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): raise ValueError("Squeezing dimensions is not " "supported for autodiff.") self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj - if wrt_traj and self.nufft_op.backend != 'gpunufft': + if wrt_traj and self.nufft_op.backend != "gpunufft": self.nufft_op.raw_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index bda18051..e9db5342 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -59,8 +59,9 @@ def list_backends(available_only=False): ] -def get_operator(backend_name: str, wrt_data: bool=False, wrt_traj: bool=False, - *args, **kwargs): +def get_operator( + backend_name: str, wrt_data: bool = False, wrt_traj: bool = False, *args, **kwargs +): """Return an MRI Fourier operator interface using the correct backend. Parameters diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 3cffc551..7dd0efcd 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -1,4 +1,5 @@ """Provides Operator for MR Image processing on GPU.""" + import warnings import numpy as np from mrinufft.operators.base import FourierOperatorBase, with_numpy_cupy diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index dc6b0609..99fcc56b 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -55,7 +55,7 @@ def _make_plan_grad(self, **kwargs): self.n_trans, self.eps, dtype="complex64" if self.samples.dtype == "float32" else "complex128", - isign = 1, + isign=1, **kwargs, ) self._set_pts(typ="grad") @@ -85,11 +85,12 @@ def op(self, coeffs_data, grid_data): grid_data = grid_data.reshape(self.shape) coeffs_data = coeffs_data.reshape(len(self.samples)) return self.plans[2].execute(grid_data, coeffs_data) - + def toggle_grad_traj(self): """Toggle between the gradient trajectory and the plan for type 1 transform.""" self.plans[2], self.grad_plan = self.grad_plan, self.plans[2] + class MRIfinufft(FourierOperatorCPU): """MRI Transform Operator using finufft. diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index 8daa031c..d976bc5c 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -186,8 +186,7 @@ def __init__( def toggle_grad_traj(self): """Toggle the gradient mode of the operator.""" self.operator.toggle_grad_mode() - - + def _reshape_image(self, image, direction="op"): """Reshape the image to the correct format.""" xp = get_array_module(image) diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index 0f76e425..ccee154d 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -50,13 +50,13 @@ def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32): trajectory = initialize_2D_radial(Nc, Ns) trajectory = rotate(trajectory, nb_rotations=Nr) return trajectory, (N, N, N) - + def case_nyquist_radial3D_lowmem(self, Nc=2, Ns=16, Nr=2, N=10): """Create a 3D radial trajectory.""" trajectory = initialize_2D_radial(Nc, Ns) trajectory = rotate(trajectory, nb_rotations=Nr) return trajectory, (N, N, N) - + def case_grid2D(self, N=16): """Create a 2D cartesian grid of frequencies locations.""" freq_1d = sp.fft.fftfreq(N) From 1fa060fa826b5ad089626ade2659ad6bfbf48458 Mon Sep 17 00:00:00 2001 From: Caini Pan <88090141+alineyyy@users.noreply.github.com> Date: Sun, 2 Jun 2024 19:40:59 +0200 Subject: [PATCH 28/45] Delete .vscode directory --- .vscode/settings.json | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9b388533..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file From 7bcec2bf5b3af3fe10096b2602576486ac479de8 Mon Sep 17 00:00:00 2001 From: Chaithya G R Date: Mon, 3 Jun 2024 10:32:32 +0200 Subject: [PATCH 29/45] Moving the test_autodiff to operators, so that it is tested --- tests/{ => operators}/test_autodiff.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => operators}/test_autodiff.py (100%) diff --git a/tests/test_autodiff.py b/tests/operators/test_autodiff.py similarity index 100% rename from tests/test_autodiff.py rename to tests/operators/test_autodiff.py From b501a0c0a76eb4d8a1831c51b2e87dbcabd9c5ae Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Mon, 3 Jun 2024 15:41:58 +0200 Subject: [PATCH 30/45] update for comments --- .github/workflows/test-ci.yml | 5 --- pyproject.toml | 2 + src/mrinufft/_utils.py | 17 +++----- src/mrinufft/density/nufft_based.py | 2 +- src/mrinufft/operators/base.py | 2 +- .../operators/interfaces/cufinufft.py | 24 +++++------ src/mrinufft/operators/interfaces/finufft.py | 16 +++----- .../operators/interfaces/nudft_numpy.py | 40 +++++++++---------- tests/operators/test_autodiff.py | 1 - tests/operators/test_bindings.py | 2 +- 10 files changed, 47 insertions(+), 64 deletions(-) diff --git a/.github/workflows/test-ci.yml b/.github/workflows/test-ci.yml index cb34f6ee..72b708a4 100644 --- a/.github/workflows/test-ci.yml +++ b/.github/workflows/test-ci.yml @@ -42,11 +42,6 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install -e .[test] - - - name: Install Torch - shell: bash - run: | - python -m pip install torch --index-url https://download.pytorch.org/whl/cu118 - name: Install pynfft if: ${{ matrix.backend == 'pynfft' || env.ref_backend == 'pynfft' }} diff --git a/pyproject.toml b/pyproject.toml index ccb33b5f..6c334b22 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,8 @@ pynfft = ["pynfft2", "cython<3.0.0"] pynufft = ["pynufft"] io = ["pymapvbvd"] smaps = ["scikit-image"] +autodiff = ["torch"] + test = ["pytest<8.0.0", "pytest-cov", "pytest-xdist", "pytest-sugar", "pytest-cases"] dev = ["black", "isort", "ruff"] diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 252aee7a..37d3105b 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -75,11 +75,11 @@ def proper_trajectory(trajectory, normalize="pi"): The normalized trajectory of shape (Nc * Ns, dim) or (Ns, dim) in -pi, pi """ # flatten to a list of point - module = get_array_module(trajectory) # check if the trajectory is a tensor + xp = get_array_module(trajectory) # check if the trajectory is a tensor try: new_traj = ( trajectory.clone() - if module.__name__ == "torch" + if xp.__name__ == "torch" else np.asarray(trajectory).copy() ) except Exception as e: @@ -89,24 +89,19 @@ def proper_trajectory(trajectory, normalize="pi"): new_traj = new_traj.reshape(-1, trajectory.shape[-1]) - max_abs_val = ( - torch.max(torch.abs(new_traj)) - if module.__name__ == "torch" - else np.max(np.abs(new_traj)) - ) + max_abs_val = xp.max(xp.abs(new_traj)) if normalize == "pi" and max_abs_val - 1e-4 < 0.5: warnings.warn( "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)" ) - new_traj *= 2 * torch.pi if module.__name__ == "torch" else 2 * np.pi + new_traj *= 2 * xp.pi elif normalize == "unit" and max_abs_val - 1e-4 > 0.5: warnings.warn( "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)" ) - new_traj *= ( - 1 / (2 * torch.pi) if module.__name__ == "torch" else 1 / (2 * np.pi) - ) + new_traj *= 1 / (2 * xp.pi) + if normalize == "unit" and max_abs_val >= 0.5: new_traj = (new_traj + 0.5) % 1 - 0.5 return new_traj diff --git a/src/mrinufft/density/nufft_based.py b/src/mrinufft/density/nufft_based.py index 4ae071ae..c2d4d7fa 100644 --- a/src/mrinufft/density/nufft_based.py +++ b/src/mrinufft/density/nufft_based.py @@ -22,7 +22,7 @@ def pipe(traj, shape, backend="gpunufft", **kwargs): # here to avoid circular import from mrinufft.operators.base import get_operator - nufft_class = get_operator(backend).func.__self__ + nufft_class = get_operator(backend) if hasattr(nufft_class, "pipe"): return nufft_class.pipe(traj, shape, **kwargs) raise ValueError("backend does not have pipe iterations method.") diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index e9db5342..d9800dcf 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -107,7 +107,7 @@ class or instance of class if args or kwargs are given. # if autograd: if isinstance(operator, FourierOperatorBase): operator = operator.make_autograd(wrt_data, wrt_traj) - else: # partial + elif(wrt_data or wrt_traj): # partial operator = partial(operator.with_autograd, wrt_data, wrt_traj) return operator diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 7dd0efcd..2f87056c 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -103,18 +103,12 @@ def _make_plan_grad(self, **kwargs): self._set_pts(typ="grad") def _set_pts(self, typ): - if typ == "grad": - self.grad_plan.setpts( - cp.array(self.samples[:, 0], copy=False), - cp.array(self.samples[:, 1], copy=False), - cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, - ) - else: - self.plans[typ].setpts( - cp.array(self.samples[:, 0], copy=False), - cp.array(self.samples[:, 1], copy=False), - cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, - ) + plan = self.grad_plan if typ == "grad" else self.plans[typ] + plan.setpts( + cp.array(self.samples[:, 0], copy=False), + cp.array(self.samples[:, 1], copy=False), + cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, + ) def _destroy_plan(self, typ): if self.plans[typ] is not None: @@ -122,6 +116,12 @@ def _destroy_plan(self, typ): del p self.plans[typ] = None + def _destroy_plan_grad(self): + if self.grad_plan is not None: + p = self.grad_plan + del p + self.grad_plan = None + def type1(self, coeff_data, grid_data): """Type 1 transform. Non Uniform to Uniform.""" return self.plans[1].execute(coeff_data, grid_data) diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 99fcc56b..7509ce14 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -61,16 +61,12 @@ def _make_plan_grad(self, **kwargs): self._set_pts(typ="grad") def _set_pts(self, typ): - if typ == "grad": - fpts_axes = [None, None, None] - for i in range(self.ndim): - fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) - self.grad_plan.setpts(*fpts_axes) - else: - fpts_axes = [None, None, None] - for i in range(self.ndim): - fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) - self.plans[typ].setpts(*fpts_axes) + fpts_axes = [None, None, None] + for i in range(self.ndim): + fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) + plan = self.grad_plan if typ == "grad" else self.plans[typ] + plan.setpts(*fpts_axes) + def adj_op(self, coeffs_data, grid_data): """Type 1 transform. Non Uniform to Uniform.""" diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index f7ddc1a0..116d11fa 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -14,31 +14,27 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) ndim = len(shape) - + dtype = module.complex64 + device = getattr(ktraj, 'device', None) + + r = [module.linspace(-s / 2, s / 2 - 1, s) for s in shape] if module.__name__ == "torch": - torch = module - device = ktraj.device - dtype = torch.complex64 - r = [torch.linspace(-s / 2, s / 2 - 1, s, device=device) for s in shape] - grid_r = torch.meshgrid(r, indexing="ij") - grid_r = torch.reshape(torch.stack(grid_r), (ndim, n)).to(device) - traj_grid = torch.matmul(ktraj, grid_r) - matrix = torch.exp(-2j * np.pi * traj_grid).to(dtype).to(device).clone() - else: - r = [np.linspace(-s / 2, s / 2 - 1, s) for s in shape] - grid_r = np.reshape(np.meshgrid(*r, indexing="ij"), (ndim, np.prod(shape))) - traj_grid = ktraj @ grid_r - matrix = np.exp(-2j * np.pi * traj_grid, dtype=dtype) + r = [x.to(device) for x in r] + grid_r = module.meshgrid(r, indexing="ij") + grid_r = module.reshape(module.stack(grid_r), (ndim, n)) + traj_grid = module.matmul(ktraj, grid_r) + matrix = module.exp(-2j * module.pi * traj_grid).to(dtype).to(device).clone() if module.__name__ == "torch" else ( + module.exp(-2j * module.pi * traj_grid, dtype=dtype) + ) + if normalize: - matrix /= ( - ( - torch.sqrt(torch.tensor(np.prod(shape), device=device)) - * torch.pow(torch.sqrt(torch.tensor(2, device=device)), ndim) - ) - if module.__name__ == "torch" - else (np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), len(shape))) - ) + norm_factor = ( + module.sqrt(module.prod(module.tensor(shape, device=device))) * module.pow(module.sqrt(module.tensor(2, device=device)), ndim) + if module.__name__ == "torch" else ( + module.sqrt(module.prod(shape)) * module.power(module.sqrt(2), ndim) + )) + matrix /= norm_factor return matrix diff --git a/tests/operators/test_autodiff.py b/tests/operators/test_autodiff.py index ede159ce..ebdaeb7e 100644 --- a/tests/operators/test_autodiff.py +++ b/tests/operators/test_autodiff.py @@ -43,7 +43,6 @@ def operator(kspace_loc, shape, backend): return nufft -# @fixture(scope="module") def ndft_matrix(operator): """Get the NDFT matrix from the operator.""" return get_fourier_matrix(operator.samples, operator.shape, normalize=True) diff --git a/tests/operators/test_bindings.py b/tests/operators/test_bindings.py index 18479988..8fb43257 100644 --- a/tests/operators/test_bindings.py +++ b/tests/operators/test_bindings.py @@ -18,7 +18,7 @@ ) def test_get_operator(backend, name): """Test the get_operator function.""" - assert mrinufft.get_operator(backend).func.__self__.__name__ == name + assert mrinufft.get_operator(backend).__name__ == name def test_get_operator_fail(): From e936fcab3577731b64288066955e68d1673faef3 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Mon, 3 Jun 2024 15:57:51 +0200 Subject: [PATCH 31/45] fix get_fourier_matrix --- src/mrinufft/operators/interfaces/nudft_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 116d11fa..11fe488a 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -21,7 +21,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): if module.__name__ == "torch": r = [x.to(device) for x in r] - grid_r = module.meshgrid(r, indexing="ij") + grid_r = module.meshgrid(*r, indexing="ij") grid_r = module.reshape(module.stack(grid_r), (ndim, n)) traj_grid = module.matmul(ktraj, grid_r) matrix = module.exp(-2j * module.pi * traj_grid).to(dtype).to(device).clone() if module.__name__ == "torch" else ( From c60c64162dd7cdc2f87fe742298a39e92d47df76 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Mon, 3 Jun 2024 16:24:08 +0200 Subject: [PATCH 32/45] style check --- examples/example_density.py | 2 +- src/mrinufft/_utils.py | 4 ++-- src/mrinufft/density/geometry_based.py | 2 +- src/mrinufft/extras/smaps.py | 4 +++- src/mrinufft/operators/base.py | 2 +- .../operators/interfaces/cufinufft.py | 12 +++++------ src/mrinufft/operators/interfaces/finufft.py | 1 - .../operators/interfaces/nudft_numpy.py | 21 +++++++++++-------- src/mrinufft/trajectories/maths.py | 8 +++---- src/mrinufft/trajectories/tools.py | 4 ++-- src/mrinufft/trajectories/trajectory3D.py | 10 ++++----- 11 files changed, 37 insertions(+), 33 deletions(-) diff --git a/examples/example_density.py b/examples/example_density.py index 20f45b3c..0497b4a8 100644 --- a/examples/example_density.py +++ b/examples/example_density.py @@ -110,7 +110,7 @@ # %% flat_traj = traj.reshape(-1, 2) -weights = np.sqrt(np.sum(flat_traj**2, axis=1)) +weights = np.sqrt(np.sum(flat_traj ** 2, axis=1)) nufft = get_operator("finufft")(traj, shape=mri_2D.shape, density=weights) adjoint_manual = nufft.adj_op(kspace) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) diff --git a/src/mrinufft/_utils.py b/src/mrinufft/_utils.py index 37d3105b..bdccc75a 100644 --- a/src/mrinufft/_utils.py +++ b/src/mrinufft/_utils.py @@ -95,12 +95,12 @@ def proper_trajectory(trajectory, normalize="pi"): warnings.warn( "Samples will be rescaled to [-pi, pi), assuming they were in [-0.5, 0.5)" ) - new_traj *= 2 * xp.pi + new_traj *= 2 * xp.pi elif normalize == "unit" and max_abs_val - 1e-4 > 0.5: warnings.warn( "Samples will be rescaled to [-0.5, 0.5), assuming they were in [-pi, pi)" ) - new_traj *= 1 / (2 * xp.pi) + new_traj *= 1 / (2 * xp.pi) if normalize == "unit" and max_abs_val >= 0.5: new_traj = (new_traj + 0.5) % 1 - 0.5 diff --git a/src/mrinufft/density/geometry_based.py b/src/mrinufft/density/geometry_based.py index 4dc0ecc5..cb091831 100644 --- a/src/mrinufft/density/geometry_based.py +++ b/src/mrinufft/density/geometry_based.py @@ -87,7 +87,7 @@ def voronoi_unique(traj, *args, **kwargs): # For edge point (infinite voronoi cells) we extrapolate from neighbours # Initial implementation in Jeff Fessler's MIRT - rho = np.sum(traj**2, axis=1) + rho = np.sum(traj ** 2, axis=1) igood = (rho > 0.6 * np.max(rho)) & ~np.isinf(wi) if len(igood) < 10: print("dubious extrapolation with", len(igood), "points") diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index 2e2ed4af..fadb365e 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -87,7 +87,9 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + window = ( + xp.sum(kspace_loc ** 2 / xp.asarray(threshold) ** 2, axis=1) <= 1 + ) else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index d9800dcf..88001f6b 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -107,7 +107,7 @@ class or instance of class if args or kwargs are given. # if autograd: if isinstance(operator, FourierOperatorBase): operator = operator.make_autograd(wrt_data, wrt_traj) - elif(wrt_data or wrt_traj): # partial + elif wrt_data or wrt_traj: # partial operator = partial(operator.with_autograd, wrt_data, wrt_traj) return operator diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 2f87056c..840ff883 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -103,12 +103,12 @@ def _make_plan_grad(self, **kwargs): self._set_pts(typ="grad") def _set_pts(self, typ): - plan = self.grad_plan if typ == "grad" else self.plans[typ] - plan.setpts( - cp.array(self.samples[:, 0], copy=False), - cp.array(self.samples[:, 1], copy=False), - cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, - ) + plan = self.grad_plan if typ == "grad" else self.plans[typ] + plan.setpts( + cp.array(self.samples[:, 0], copy=False), + cp.array(self.samples[:, 1], copy=False), + cp.array(self.samples[:, 2], copy=False) if self.ndim == 3 else None, + ) def _destroy_plan(self, typ): if self.plans[typ] is not None: diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 7509ce14..cae64965 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -66,7 +66,6 @@ def _set_pts(self, typ): fpts_axes[i] = np.array(self.samples[:, i], dtype=self.samples.dtype) plan = self.grad_plan if typ == "grad" else self.plans[typ] plan.setpts(*fpts_axes) - def adj_op(self, coeffs_data, grid_data): """Type 1 transform. Non Uniform to Uniform.""" diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 11fe488a..6847c87e 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -15,8 +15,8 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): n = np.prod(shape) ndim = len(shape) dtype = module.complex64 - device = getattr(ktraj, 'device', None) - + device = getattr(ktraj, "device", None) + r = [module.linspace(-s / 2, s / 2 - 1, s) for s in shape] if module.__name__ == "torch": r = [x.to(device) for x in r] @@ -24,16 +24,19 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): grid_r = module.meshgrid(*r, indexing="ij") grid_r = module.reshape(module.stack(grid_r), (ndim, n)) traj_grid = module.matmul(ktraj, grid_r) - matrix = module.exp(-2j * module.pi * traj_grid).to(dtype).to(device).clone() if module.__name__ == "torch" else ( - module.exp(-2j * module.pi * traj_grid, dtype=dtype) + matrix = ( + module.exp(-2j * module.pi * traj_grid).to(dtype).to(device).clone() + if module.__name__ == "torch" + else (module.exp(-2j * module.pi * traj_grid, dtype=dtype)) ) - + if normalize: norm_factor = ( - module.sqrt(module.prod(module.tensor(shape, device=device))) * module.pow(module.sqrt(module.tensor(2, device=device)), ndim) - if module.__name__ == "torch" else ( - module.sqrt(module.prod(shape)) * module.power(module.sqrt(2), ndim) - )) + module.sqrt(module.prod(module.tensor(shape, device=device))) + * module.pow(module.sqrt(module.tensor(2, device=device)), ndim) + if module.__name__ == "torch" + else (module.sqrt(module.prod(shape)) * module.power(module.sqrt(2), ndim)) + ) matrix /= norm_factor return matrix diff --git a/src/mrinufft/trajectories/maths.py b/src/mrinufft/trajectories/maths.py index a413df2c..631b05ab 100644 --- a/src/mrinufft/trajectories/maths.py +++ b/src/mrinufft/trajectories/maths.py @@ -187,19 +187,19 @@ def Ra(vector, theta): return np.array( [ [ - cos_t + v_x**2 * (1 - cos_t), + cos_t + v_x ** 2 * (1 - cos_t), v_x * v_y * (1 - cos_t) + v_z * sin_t, v_x * v_z * (1 - cos_t) - v_y * sin_t, ], [ v_y * v_x * (1 - cos_t) - v_z * sin_t, - cos_t + v_y**2 * (1 - cos_t), + cos_t + v_y ** 2 * (1 - cos_t), v_y * v_z * (1 - cos_t) + v_x * sin_t, ], [ v_z * v_x * (1 - cos_t) + v_y * sin_t, v_z * v_y * (1 - cos_t) - v_x * sin_t, - cos_t + v_z**2 * (1 - cos_t), + cos_t + v_z ** 2 * (1 - cos_t), ], ] ) @@ -232,7 +232,7 @@ def _is_perfect_square(n): r = int(np.sqrt(n)) return r * r == n - return _is_perfect_square(5 * n**2 + 4) or _is_perfect_square(5 * n**2 - 4) + return _is_perfect_square(5 * n ** 2 + 4) or _is_perfect_square(5 * n ** 2 - 4) def get_closest_fibonacci_number(x): diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 316190e9..31d2ad1b 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -307,7 +307,7 @@ def stack_spherically( # Attribute shots to stacks following density proportional to surface Nc_per_stack = np.ones(nb_stacks).astype(int) - density = radii**2 # simplified version + density = radii ** 2 # simplified version for _ in range(Nc - nb_stacks): idx = np.argmax(density / Nc_per_stack) Nc_per_stack[idx] += 1 @@ -403,7 +403,7 @@ def shellify( ) # Carve upper hemisphere from trajectory - z_coords = KMAX**2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 + z_coords = KMAX ** 2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 z_signs = np.sign(z_coords) shell_upper[..., 2] += z_signs * np.sqrt(np.abs(z_coords)) diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index b67d7833..9b261e04 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -239,7 +239,7 @@ def initialize_3D_wave_caipi( elif packing == Packings.CIRCLE: positions = [[0, 0]] counter = 0 - while len(positions) < side**2: + while len(positions) < side ** 2: counter += 1 perimeter = 2 * np.pi * counter nb_shots = int(np.trunc(perimeter)) @@ -352,11 +352,11 @@ def initialize_3D_seiffert_spiral( """ # Normalize ellipses integrations by the requested period spiral = np.zeros((1, Ns // (1 + in_out), 3)) - period = 4 * ellipk(curve_index**2) + period = 4 * ellipk(curve_index ** 2) times = np.linspace(0, nb_revolutions * period, Ns // (1 + in_out), endpoint=False) # Initialize first shot - jacobi = ellipj(times, curve_index**2) + jacobi = ellipj(times, curve_index ** 2) spiral[0, :, 0] = jacobi[0] * np.cos(curve_index * times) spiral[0, :, 1] = jacobi[0] * np.sin(curve_index * times) spiral[0, :, 2] = jacobi[1] @@ -654,7 +654,7 @@ def initialize_3D_seiffert_shells( Nc_per_shell[idx] += 1 # Normalize ellipses integrations by the requested period - period = 4 * ellipk(curve_index**2) + period = 4 * ellipk(curve_index ** 2) times = np.linspace(0, nb_revolutions * period, Ns, endpoint=False) # Create shells one by one @@ -666,7 +666,7 @@ def initialize_3D_seiffert_shells( k0 = radii[i] # Initialize first shot - jacobi = ellipj(times, curve_index**2) + jacobi = ellipj(times, curve_index ** 2) trajectory[count, :, 0] = k0 * jacobi[0] * np.cos(curve_index * times) trajectory[count, :, 1] = k0 * jacobi[0] * np.sin(curve_index * times) trajectory[count, :, 2] = k0 * jacobi[1] From c52dc2fe55f8eb6875023d16a9fac8c3315075c9 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Mon, 3 Jun 2024 16:32:56 +0200 Subject: [PATCH 33/45] style check --- examples/example_density.py | 2 +- src/mrinufft/density/geometry_based.py | 2 +- src/mrinufft/extras/smaps.py | 4 +--- src/mrinufft/trajectories/maths.py | 8 ++++---- src/mrinufft/trajectories/tools.py | 4 ++-- src/mrinufft/trajectories/trajectory3D.py | 10 +++++----- 6 files changed, 14 insertions(+), 16 deletions(-) diff --git a/examples/example_density.py b/examples/example_density.py index 0497b4a8..20f45b3c 100644 --- a/examples/example_density.py +++ b/examples/example_density.py @@ -110,7 +110,7 @@ # %% flat_traj = traj.reshape(-1, 2) -weights = np.sqrt(np.sum(flat_traj ** 2, axis=1)) +weights = np.sqrt(np.sum(flat_traj**2, axis=1)) nufft = get_operator("finufft")(traj, shape=mri_2D.shape, density=weights) adjoint_manual = nufft.adj_op(kspace) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) diff --git a/src/mrinufft/density/geometry_based.py b/src/mrinufft/density/geometry_based.py index cb091831..4dc0ecc5 100644 --- a/src/mrinufft/density/geometry_based.py +++ b/src/mrinufft/density/geometry_based.py @@ -87,7 +87,7 @@ def voronoi_unique(traj, *args, **kwargs): # For edge point (infinite voronoi cells) we extrapolate from neighbours # Initial implementation in Jeff Fessler's MIRT - rho = np.sum(traj ** 2, axis=1) + rho = np.sum(traj**2, axis=1) igood = (rho > 0.6 * np.max(rho)) & ~np.isinf(wi) if len(igood) < 10: print("dubious extrapolation with", len(igood), "points") diff --git a/src/mrinufft/extras/smaps.py b/src/mrinufft/extras/smaps.py index fadb365e..2e2ed4af 100644 --- a/src/mrinufft/extras/smaps.py +++ b/src/mrinufft/extras/smaps.py @@ -87,9 +87,7 @@ def _extract_kspace_center( a_0 = 0.5 if window_fun in ["hann", "hanning"] else 0.53836 window = a_0 + (1 - a_0) * xp.cos(xp.pi * radius / threshold) elif window_fun == "ellipse": - window = ( - xp.sum(kspace_loc ** 2 / xp.asarray(threshold) ** 2, axis=1) <= 1 - ) + window = xp.sum(kspace_loc**2 / xp.asarray(threshold) ** 2, axis=1) <= 1 else: raise ValueError("Unsupported window function.") data_thresholded = window * kspace_data diff --git a/src/mrinufft/trajectories/maths.py b/src/mrinufft/trajectories/maths.py index 631b05ab..a413df2c 100644 --- a/src/mrinufft/trajectories/maths.py +++ b/src/mrinufft/trajectories/maths.py @@ -187,19 +187,19 @@ def Ra(vector, theta): return np.array( [ [ - cos_t + v_x ** 2 * (1 - cos_t), + cos_t + v_x**2 * (1 - cos_t), v_x * v_y * (1 - cos_t) + v_z * sin_t, v_x * v_z * (1 - cos_t) - v_y * sin_t, ], [ v_y * v_x * (1 - cos_t) - v_z * sin_t, - cos_t + v_y ** 2 * (1 - cos_t), + cos_t + v_y**2 * (1 - cos_t), v_y * v_z * (1 - cos_t) + v_x * sin_t, ], [ v_z * v_x * (1 - cos_t) + v_y * sin_t, v_z * v_y * (1 - cos_t) - v_x * sin_t, - cos_t + v_z ** 2 * (1 - cos_t), + cos_t + v_z**2 * (1 - cos_t), ], ] ) @@ -232,7 +232,7 @@ def _is_perfect_square(n): r = int(np.sqrt(n)) return r * r == n - return _is_perfect_square(5 * n ** 2 + 4) or _is_perfect_square(5 * n ** 2 - 4) + return _is_perfect_square(5 * n**2 + 4) or _is_perfect_square(5 * n**2 - 4) def get_closest_fibonacci_number(x): diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 31d2ad1b..316190e9 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -307,7 +307,7 @@ def stack_spherically( # Attribute shots to stacks following density proportional to surface Nc_per_stack = np.ones(nb_stacks).astype(int) - density = radii ** 2 # simplified version + density = radii**2 # simplified version for _ in range(Nc - nb_stacks): idx = np.argmax(density / Nc_per_stack) Nc_per_stack[idx] += 1 @@ -403,7 +403,7 @@ def shellify( ) # Carve upper hemisphere from trajectory - z_coords = KMAX ** 2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 + z_coords = KMAX**2 - shell_upper[..., 0] ** 2 - shell_upper[..., 1] ** 2 z_signs = np.sign(z_coords) shell_upper[..., 2] += z_signs * np.sqrt(np.abs(z_coords)) diff --git a/src/mrinufft/trajectories/trajectory3D.py b/src/mrinufft/trajectories/trajectory3D.py index 9b261e04..b67d7833 100644 --- a/src/mrinufft/trajectories/trajectory3D.py +++ b/src/mrinufft/trajectories/trajectory3D.py @@ -239,7 +239,7 @@ def initialize_3D_wave_caipi( elif packing == Packings.CIRCLE: positions = [[0, 0]] counter = 0 - while len(positions) < side ** 2: + while len(positions) < side**2: counter += 1 perimeter = 2 * np.pi * counter nb_shots = int(np.trunc(perimeter)) @@ -352,11 +352,11 @@ def initialize_3D_seiffert_spiral( """ # Normalize ellipses integrations by the requested period spiral = np.zeros((1, Ns // (1 + in_out), 3)) - period = 4 * ellipk(curve_index ** 2) + period = 4 * ellipk(curve_index**2) times = np.linspace(0, nb_revolutions * period, Ns // (1 + in_out), endpoint=False) # Initialize first shot - jacobi = ellipj(times, curve_index ** 2) + jacobi = ellipj(times, curve_index**2) spiral[0, :, 0] = jacobi[0] * np.cos(curve_index * times) spiral[0, :, 1] = jacobi[0] * np.sin(curve_index * times) spiral[0, :, 2] = jacobi[1] @@ -654,7 +654,7 @@ def initialize_3D_seiffert_shells( Nc_per_shell[idx] += 1 # Normalize ellipses integrations by the requested period - period = 4 * ellipk(curve_index ** 2) + period = 4 * ellipk(curve_index**2) times = np.linspace(0, nb_revolutions * period, Ns, endpoint=False) # Create shells one by one @@ -666,7 +666,7 @@ def initialize_3D_seiffert_shells( k0 = radii[i] # Initialize first shot - jacobi = ellipj(times, curve_index ** 2) + jacobi = ellipj(times, curve_index**2) trajectory[count, :, 0] = k0 * jacobi[0] * np.cos(curve_index * times) trajectory[count, :, 1] = k0 * jacobi[0] * np.sin(curve_index * times) trajectory[count, :, 2] = k0 * jacobi[1] From c279683f698b7eb9ad03057b9a7b6bc99c3e0e16 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 11:34:55 +0200 Subject: [PATCH 34/45] update for PAC'S comments --- src/mrinufft/operators/autodiff.py | 28 +++++------ src/mrinufft/operators/base.py | 15 ++++-- .../operators/interfaces/cufinufft.py | 1 + .../operators/interfaces/nudft_numpy.py | 48 ++++++++++++------- tests/case_trajectories.py | 2 +- 5 files changed, 57 insertions(+), 37 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 012fcb18..87eecc0a 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -20,7 +20,7 @@ class _NUFFT_OP(torch.autograd.Function): @staticmethod def forward(ctx, x, traj, nufft_op): """Forward image -> k-space.""" - ctx.save_for_backward(x, traj) # nufft_op.samples => traj + ctx.save_for_backward(x, traj) ctx.nufft_op = nufft_op return nufft_op.op(x) @@ -28,7 +28,8 @@ def forward(ctx, x, traj, nufft_op): def backward(ctx, dy): """Backward image -> k-space.""" (x, traj) = ctx.saved_tensors - + grad_data = None + grad_traj = None if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.adj_op(dy) if ctx.nufft_op._grad_wrt_traj: @@ -46,19 +47,18 @@ def backward(ctx, dy): grid_x = x * grid_r # Element-wise multiplication: x * r nufft_dx_dom = torch.cat( [ - ctx.nufft_op.op(grid_x[:, i : i + 1, :, :]) + ctx.nufft_op.op(grid_x[:, i, :, :]) for i in range(grid_x.size(1)) ], dim=1, - ) # Compute A(x * r) for each channel and concatenate along this dimension + ) grad_traj = torch.transpose( (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 ).type_as( traj - ) # Compute gradient with respect to trajectory: -i * dy' * A(x * r) - else: - grad_traj = None + ) + return grad_data, grad_traj, None @@ -75,14 +75,11 @@ def forward(ctx, y, traj, nufft_op): @staticmethod def backward(ctx, dx): """Backward kspace -> image.""" - (y, traj) = ctx.saved_tensors # y [1, 256] traj [256, 2] - + (y, traj) = ctx.saved_tensors grad_data = None grad_traj = None - print("In AutoGrad") if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.op(dx) - if ctx.nufft_op._grad_wrt_traj: ctx.nufft_op.raw_op.toggle_grad_traj() im_size = dx.size()[2:] @@ -94,12 +91,12 @@ def backward(ctx, dx): for size in im_size ] grid_r = torch.meshgrid(*r, indexing="ij") - grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] # [1, 2, 16, 16] + grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] grid_dx = torch.conj(dx) * grid_r inufft_dx_dom = torch.cat( [ - ctx.nufft_op.op(grid_dx[:, i : i + 1, :, :]) + ctx.nufft_op.op(grid_dx[:, i , :, :]) for i in range(grid_dx.size(1)) ], dim=1, @@ -126,10 +123,11 @@ class MRINufftAutoGrad(torch.nn.Module): def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): super().__init__() if (wrt_data or wrt_traj) and nufft_op.squeeze_dims: - raise ValueError("Squeezing dimensions is not " "supported for autodiff.") + raise ValueError("Squeezing dimensions is not supported for autodiff.") + self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj - if wrt_traj and self.nufft_op.backend != "gpunufft": + if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]: self.nufft_op.raw_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 88001f6b..6ecf6094 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -107,7 +107,7 @@ class or instance of class if args or kwargs are given. # if autograd: if isinstance(operator, FourierOperatorBase): operator = operator.make_autograd(wrt_data, wrt_traj) - elif wrt_data or wrt_traj: # partial + elif wrt_data or wrt_traj: # instance will be created later operator = partial(operator.with_autograd, wrt_data, wrt_traj) return operator @@ -192,13 +192,14 @@ class FourierOperatorBase(ABC): """ interfaces: dict[str, tuple] = {} - + def __init__(self): if not self.available: raise RuntimeError(f"'{self.backend}' backend is not available.") self._smaps = None self._density = None self._n_coils = 1 + self.autograd_available = False def __init_subclass__(cls): """Register the class in the list of available operators.""" @@ -299,6 +300,12 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): variable: , default data variable on which the gradient is computed with respect to. + wrt_data : bool, optional + Whether to compute the gradient with respect to the data, default is True + + wrt_traj : bool, optional + Whether to compute the gradient with respect to the trajectory, default is False + Returns ------- torch.nn.module @@ -309,7 +316,7 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): ValueError If autograd is not available. """ - if not AUTOGRAD_AVAILABLE: + if not AUTOGRAD_AVAILABLE or not self.autograd_available: raise ValueError("Autograd not available, ensure torch is installed.") return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj) @@ -484,7 +491,7 @@ def __repr__(self): ) @classmethod - def with_autograd(cls, wrt_data, wrt_traj, *args, **kwargs): + def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs): """Return a Fourier operator with autograd capabilities.""" return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 840ff883..72f0d4e1 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -219,6 +219,7 @@ def __init__( self.n_trans = n_trans self.squeeze_dims = squeeze_dims self.n_coils = n_coils + self.autograd_available = True # For now only single precision is supported self.samples = np.asfortranarray( proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 6847c87e..36c6a2e7 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -9,34 +9,48 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): - """Get the NDFT Fourier Matrix.""" - module = get_array_module(ktraj) + """Generates a Fourier matrix for non-uniform k-space trajectories. + + Parameters + ---------- + ktraj : array_like + The k-space coordinates for the Fourier transformation. + shape : tuple of int + The dimensions of the output Fourier matrix. + dtype : data-type, optional + The data type of the Fourier matrix, default is np.complex64. + normalize : bool, optional + If True, normalizes the matrix to maintain numerical stability. + + Returns + ------- + matrix + The NDFT Fourier Matrix. + """ + xp = get_array_module(ktraj) ktraj = proper_trajectory(ktraj, normalize="unit") n = np.prod(shape) ndim = len(shape) - dtype = module.complex64 + dtype = xp.complex64 device = getattr(ktraj, "device", None) - r = [module.linspace(-s / 2, s / 2 - 1, s) for s in shape] - if module.__name__ == "torch": + r = [xp.linspace(-s / 2, s / 2 - 1, s) for s in shape] + if xp.__name__ == "torch": r = [x.to(device) for x in r] - grid_r = module.meshgrid(*r, indexing="ij") - grid_r = module.reshape(module.stack(grid_r), (ndim, n)) - traj_grid = module.matmul(ktraj, grid_r) + grid_r = xp.meshgrid(*r, indexing="ij") + grid_r = xp.reshape(xp.stack(grid_r), (ndim, n)) + traj_grid = xp.matmul(ktraj, grid_r) matrix = ( - module.exp(-2j * module.pi * traj_grid).to(dtype).to(device).clone() - if module.__name__ == "torch" - else (module.exp(-2j * module.pi * traj_grid, dtype=dtype)) + xp.exp(-2j * xp.pi * traj_grid).to(dtype).to(device).clone() + if xp.__name__ == "torch" + else (xp.exp(-2j * xp.pi * traj_grid, dtype=dtype)) ) if normalize: - norm_factor = ( - module.sqrt(module.prod(module.tensor(shape, device=device))) - * module.pow(module.sqrt(module.tensor(2, device=device)), ndim) - if module.__name__ == "torch" - else (module.sqrt(module.prod(shape)) * module.power(module.sqrt(2), ndim)) - ) + norm_factor = np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), ndim) + if xp.__name__ == "torch": + norm_factor = xp.tensor(norm_factor, device=device) matrix /= norm_factor return matrix diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index ccee154d..97c2df59 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -52,7 +52,7 @@ def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32): return trajectory, (N, N, N) def case_nyquist_radial3D_lowmem(self, Nc=2, Ns=16, Nr=2, N=10): - """Create a 3D radial trajectory.""" + """Create a 3D radial trajectory with low memory.""" trajectory = initialize_2D_radial(Nc, Ns) trajectory = rotate(trajectory, nb_rotations=Nr) return trajectory, (N, N, N) From ebc1f5616cf223500964fd128a91135b43a98730 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 13:41:25 +0200 Subject: [PATCH 35/45] fix autograd_available --- src/mrinufft/operators/base.py | 4 ++-- src/mrinufft/operators/interfaces/cufinufft.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 6ecf6094..115390b0 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -192,14 +192,14 @@ class FourierOperatorBase(ABC): """ interfaces: dict[str, tuple] = {} - + autograd_available = False + def __init__(self): if not self.available: raise RuntimeError(f"'{self.backend}' backend is not available.") self._smaps = None self._density = None self._n_coils = 1 - self.autograd_available = False def __init_subclass__(cls): """Register the class in the list of available operators.""" diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 72f0d4e1..630a68ab 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -189,7 +189,8 @@ class MRICufiNUFFT(FourierOperatorBase): backend = "cufinufft" available = CUFINUFFT_AVAILABLE and CUPY_AVAILABLE - + autograd_available = True + def __init__( self, samples, From 9f7c5d705712c4eaae89eef02d078be2d3f4d9ce Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 16:07:58 +0200 Subject: [PATCH 36/45] fix autograd_available --- src/mrinufft/operators/interfaces/finufft.py | 3 ++- src/mrinufft/operators/interfaces/gpunufft.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index cae64965..1e5bc6fd 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -117,7 +117,8 @@ class MRIfinufft(FourierOperatorCPU): backend = "finufft" available = FINUFFT_AVAILABLE - + autograd_available = True + def __init__( self, samples, diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index d976bc5c..dc145f27 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -356,7 +356,8 @@ class MRIGpuNUFFT(FourierOperatorBase): backend = "gpunufft" available = GPUNUFFT_AVAILABLE and CUPY_AVAILABLE - + autograd_available = True + def __init__( self, samples, From cfa781bb96f0365c2a4c53409d0d920c1cca1df6 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 16:37:45 +0200 Subject: [PATCH 37/45] fix style checking --- src/mrinufft/operators/autodiff.py | 32 +++++++------------ src/mrinufft/operators/base.py | 8 ++--- .../operators/interfaces/cufinufft.py | 4 +-- src/mrinufft/operators/interfaces/finufft.py | 2 +- src/mrinufft/operators/interfaces/gpunufft.py | 2 +- .../operators/interfaces/nudft_numpy.py | 2 +- tests/case_trajectories.py | 2 +- 7 files changed, 22 insertions(+), 30 deletions(-) diff --git a/src/mrinufft/operators/autodiff.py b/src/mrinufft/operators/autodiff.py index 87eecc0a..3353535e 100644 --- a/src/mrinufft/operators/autodiff.py +++ b/src/mrinufft/operators/autodiff.py @@ -20,7 +20,7 @@ class _NUFFT_OP(torch.autograd.Function): @staticmethod def forward(ctx, x, traj, nufft_op): """Forward image -> k-space.""" - ctx.save_for_backward(x, traj) + ctx.save_for_backward(x, traj) ctx.nufft_op = nufft_op return nufft_op.op(x) @@ -28,8 +28,8 @@ def forward(ctx, x, traj, nufft_op): def backward(ctx, dy): """Backward image -> k-space.""" (x, traj) = ctx.saved_tensors - grad_data = None - grad_traj = None + grad_data = None + grad_traj = None if ctx.nufft_op._grad_wrt_data: grad_data = ctx.nufft_op.adj_op(dy) if ctx.nufft_op._grad_wrt_traj: @@ -46,19 +46,14 @@ def backward(ctx, dy): grid_x = x * grid_r # Element-wise multiplication: x * r nufft_dx_dom = torch.cat( - [ - ctx.nufft_op.op(grid_x[:, i, :, :]) - for i in range(grid_x.size(1)) - ], + [ctx.nufft_op.op(grid_x[:, i, :, :]) for i in range(grid_x.size(1))], dim=1, - ) + ) grad_traj = torch.transpose( (-1j * torch.conj(dy) * nufft_dx_dom).squeeze(), 0, 1 - ).type_as( - traj - ) - + ).type_as(traj) + return grad_data, grad_traj, None @@ -75,7 +70,7 @@ def forward(ctx, y, traj, nufft_op): @staticmethod def backward(ctx, dx): """Backward kspace -> image.""" - (y, traj) = ctx.saved_tensors + (y, traj) = ctx.saved_tensors grad_data = None grad_traj = None if ctx.nufft_op._grad_wrt_data: @@ -91,14 +86,11 @@ def backward(ctx, dx): for size in im_size ] grid_r = torch.meshgrid(*r, indexing="ij") - grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] + grid_r = torch.stack(grid_r, dim=0).type_as(dx)[None, ...] grid_dx = torch.conj(dx) * grid_r inufft_dx_dom = torch.cat( - [ - ctx.nufft_op.op(grid_dx[:, i , :, :]) - for i in range(grid_dx.size(1)) - ], + [ctx.nufft_op.op(grid_dx[:, i, :, :]) for i in range(grid_dx.size(1))], dim=1, ) @@ -123,11 +115,11 @@ class MRINufftAutoGrad(torch.nn.Module): def __init__(self, nufft_op, wrt_data=True, wrt_traj=False): super().__init__() if (wrt_data or wrt_traj) and nufft_op.squeeze_dims: - raise ValueError("Squeezing dimensions is not supported for autodiff.") + raise ValueError("Squeezing dimensions is not supported for autodiff.") self.nufft_op = nufft_op self.nufft_op._grad_wrt_traj = wrt_traj - if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]: + if wrt_traj and self.nufft_op.backend in ["finufft", "cufinufft"]: self.nufft_op.raw_op._make_plan_grad() self.nufft_op._grad_wrt_data = wrt_data diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 115390b0..c6cf4143 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -301,10 +301,10 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): variable on which the gradient is computed with respect to. wrt_data : bool, optional - Whether to compute the gradient with respect to the data, default is True - + If the gradient with respect to the data is computed, default is true + wrt_traj : bool, optional - Whether to compute the gradient with respect to the trajectory, default is False + If the gradient with respect to the trajectory is computed, default is false Returns ------- @@ -491,7 +491,7 @@ def __repr__(self): ) @classmethod - def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs): + def with_autograd(cls, wrt_data=True, wrt_traj=False, *args, **kwargs): """Return a Fourier operator with autograd capabilities.""" return cls(*args, **kwargs).make_autograd(wrt_data, wrt_traj) diff --git a/src/mrinufft/operators/interfaces/cufinufft.py b/src/mrinufft/operators/interfaces/cufinufft.py index 630a68ab..f5703547 100644 --- a/src/mrinufft/operators/interfaces/cufinufft.py +++ b/src/mrinufft/operators/interfaces/cufinufft.py @@ -189,8 +189,8 @@ class MRICufiNUFFT(FourierOperatorBase): backend = "cufinufft" available = CUFINUFFT_AVAILABLE and CUPY_AVAILABLE - autograd_available = True - + autograd_available = True + def __init__( self, samples, diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 1e5bc6fd..21e9618d 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -118,7 +118,7 @@ class MRIfinufft(FourierOperatorCPU): backend = "finufft" available = FINUFFT_AVAILABLE autograd_available = True - + def __init__( self, samples, diff --git a/src/mrinufft/operators/interfaces/gpunufft.py b/src/mrinufft/operators/interfaces/gpunufft.py index dc145f27..a2091ce6 100644 --- a/src/mrinufft/operators/interfaces/gpunufft.py +++ b/src/mrinufft/operators/interfaces/gpunufft.py @@ -357,7 +357,7 @@ class MRIGpuNUFFT(FourierOperatorBase): backend = "gpunufft" available = GPUNUFFT_AVAILABLE and CUPY_AVAILABLE autograd_available = True - + def __init__( self, samples, diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 36c6a2e7..5398f347 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -9,7 +9,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): - """Generates a Fourier matrix for non-uniform k-space trajectories. + """Get the NDFT Fourier Matrix. Parameters ---------- diff --git a/tests/case_trajectories.py b/tests/case_trajectories.py index 97c2df59..9f7296bd 100644 --- a/tests/case_trajectories.py +++ b/tests/case_trajectories.py @@ -52,7 +52,7 @@ def case_nyquist_radial3D(self, Nc=32 * 4, Ns=16, Nr=32 * 4, N=32): return trajectory, (N, N, N) def case_nyquist_radial3D_lowmem(self, Nc=2, Ns=16, Nr=2, N=10): - """Create a 3D radial trajectory with low memory.""" + """Create a 3D radial trajectory with low memory.""" trajectory = initialize_2D_radial(Nc, Ns) trajectory = rotate(trajectory, nb_rotations=Nr) return trajectory, (N, N, N) From 6d4e4c58303ec9c44a5d500cc10f52dfff245372 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 17:47:58 +0200 Subject: [PATCH 38/45] reduce the tolerance in test_autodiff --- tests/operators/test_autodiff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_autodiff.py b/tests/operators/test_autodiff.py index ebdaeb7e..454ca3a7 100644 --- a/tests/operators/test_autodiff.py +++ b/tests/operators/test_autodiff.py @@ -79,7 +79,7 @@ def test_adjoint_and_grad(operator, interface): gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -121,7 +121,7 @@ def test_forward_and_grad(operator, interface): gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] From b906980765b027df0a7e8a6ff0f71a9f695c11da Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Tue, 4 Jun 2024 17:59:12 +0200 Subject: [PATCH 39/45] set the proper tolerance in test_autodiff --- tests/operators/test_autodiff.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/operators/test_autodiff.py b/tests/operators/test_autodiff.py index 454ca3a7..ebdaeb7e 100644 --- a/tests/operators/test_autodiff.py +++ b/tests/operators/test_autodiff.py @@ -79,7 +79,7 @@ def test_adjoint_and_grad(operator, interface): gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, ksp_data, retain_graph=True)[0] @@ -121,7 +121,7 @@ def test_forward_and_grad(operator, interface): gradient_nufft_ktraj = torch.autograd.grad( loss_nufft, operator.samples, retain_graph=True )[0] - assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-7) + assert torch.allclose(gradient_ndft_ktraj, gradient_nufft_ktraj, atol=5e-2) # Check if nufft and ndft are close in the backprop gradient_ndft_kdata = torch.autograd.grad(loss_ndft, img_data, retain_graph=True)[0] From dbce464d8f90f446e34ceb5f1217059fb9b1a6f2 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 6 Jun 2024 13:44:37 +0200 Subject: [PATCH 40/45] update for comments --- src/mrinufft/operators/base.py | 6 ++++-- src/mrinufft/operators/interfaces/finufft.py | 5 +++-- src/mrinufft/operators/interfaces/nudft_numpy.py | 9 ++++----- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index c6cf4143..4c20df42 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -316,8 +316,10 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): ValueError If autograd is not available. """ - if not AUTOGRAD_AVAILABLE or not self.autograd_available: - raise ValueError("Autograd not available, ensure torch is installed.") + if not AUTOGRAD_AVAILABLE: + raise ValueError("Autograd not available, ensure torch is installed.") + if not self.autograd_available: + raise ValueError("Backend does not support auto-differentiation.") return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj) diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 21e9618d..90a1a755 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -11,6 +11,7 @@ except ImportError: FINUFFT_AVAILABLE = False +DTYPE_R2C = {"float32": "complex64", "float64": "complex128"} class RawFinufftPlan: """Light wrapper around the guru interface of finufft.""" @@ -44,7 +45,7 @@ def _make_plan(self, typ, **kwargs): self.shape, self.n_trans, self.eps, - dtype="complex64" if self.samples.dtype == "float32" else "complex128", + dtype=DTYPE_R2C[str(self.samples.dtype)], **kwargs, ) @@ -54,7 +55,7 @@ def _make_plan_grad(self, **kwargs): self.shape, self.n_trans, self.eps, - dtype="complex64" if self.samples.dtype == "float32" else "complex128", + dtype=DTYPE_R2C[str(self.samples.dtype)], isign=1, **kwargs, ) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 5398f347..502f75f3 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -41,11 +41,10 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): grid_r = xp.meshgrid(*r, indexing="ij") grid_r = xp.reshape(xp.stack(grid_r), (ndim, n)) traj_grid = xp.matmul(ktraj, grid_r) - matrix = ( - xp.exp(-2j * xp.pi * traj_grid).to(dtype).to(device).clone() - if xp.__name__ == "torch" - else (xp.exp(-2j * xp.pi * traj_grid, dtype=dtype)) - ) + matrix = xp.exp(-2j * xp.pi * traj_grid) + if xp.__name__ == "torch": + matrix.to(dtype=dtype, device=device, copy=True) + if normalize: norm_factor = np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), ndim) From aef0b8f3074243b62c8ace4e9bb00b6c55657874 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 6 Jun 2024 13:52:01 +0200 Subject: [PATCH 41/45] black style checking --- src/mrinufft/operators/base.py | 8 ++++---- src/mrinufft/operators/interfaces/finufft.py | 1 + src/mrinufft/operators/interfaces/nudft_numpy.py | 7 +++---- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 4c20df42..ea8583ed 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -316,10 +316,10 @@ def make_autograd(self, wrt_data=True, wrt_traj=False): ValueError If autograd is not available. """ - if not AUTOGRAD_AVAILABLE: - raise ValueError("Autograd not available, ensure torch is installed.") - if not self.autograd_available: - raise ValueError("Backend does not support auto-differentiation.") + if not AUTOGRAD_AVAILABLE: + raise ValueError("Autograd not available, ensure torch is installed.") + if not self.autograd_available: + raise ValueError("Backend does not support auto-differentiation.") return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj) diff --git a/src/mrinufft/operators/interfaces/finufft.py b/src/mrinufft/operators/interfaces/finufft.py index 90a1a755..b4b5f807 100644 --- a/src/mrinufft/operators/interfaces/finufft.py +++ b/src/mrinufft/operators/interfaces/finufft.py @@ -13,6 +13,7 @@ DTYPE_R2C = {"float32": "complex64", "float64": "complex128"} + class RawFinufftPlan: """Light wrapper around the guru interface of finufft.""" diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 502f75f3..2c011df4 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -41,10 +41,9 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): grid_r = xp.meshgrid(*r, indexing="ij") grid_r = xp.reshape(xp.stack(grid_r), (ndim, n)) traj_grid = xp.matmul(ktraj, grid_r) - matrix = xp.exp(-2j * xp.pi * traj_grid) - if xp.__name__ == "torch": - matrix.to(dtype=dtype, device=device, copy=True) - + matrix = xp.exp(-2j * xp.pi * traj_grid) + if xp.__name__ == "torch": + matrix.to(dtype=dtype, device=device, copy=True) if normalize: norm_factor = np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), ndim) From 5e591a7c94469789bfce5e0f5a52bf7c3cdd5818 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 6 Jun 2024 14:03:13 +0200 Subject: [PATCH 42/45] black style checking --- .vscode/settings.json | 7 +++++++ 1 file changed, 7 insertions(+) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..9b388533 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,7 @@ +{ + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true +} \ No newline at end of file From 7ff8061c2c37d1ae809228e6808ef5ee1901151e Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 6 Jun 2024 14:36:31 +0200 Subject: [PATCH 43/45] fix get_fourier_matrix --- src/mrinufft/operators/interfaces/nudft_numpy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mrinufft/operators/interfaces/nudft_numpy.py b/src/mrinufft/operators/interfaces/nudft_numpy.py index 2c011df4..cc4f4806 100644 --- a/src/mrinufft/operators/interfaces/nudft_numpy.py +++ b/src/mrinufft/operators/interfaces/nudft_numpy.py @@ -43,7 +43,7 @@ def get_fourier_matrix(ktraj, shape, dtype=np.complex64, normalize=False): traj_grid = xp.matmul(ktraj, grid_r) matrix = xp.exp(-2j * xp.pi * traj_grid) if xp.__name__ == "torch": - matrix.to(dtype=dtype, device=device, copy=True) + matrix = matrix.to(dtype=dtype, device=device, copy=True) if normalize: norm_factor = np.sqrt(np.prod(shape)) * np.power(np.sqrt(2), ndim) From 88c374dee37613a83ebf3d36779eec3b22017017 Mon Sep 17 00:00:00 2001 From: Caini Pan <88090141+alineyyy@users.noreply.github.com> Date: Thu, 6 Jun 2024 15:25:08 +0200 Subject: [PATCH 44/45] Delete .vscode directory --- .vscode/settings.json | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 9b388533..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,7 +0,0 @@ -{ - "python.testing.pytestArgs": [ - "tests" - ], - "python.testing.unittestEnabled": false, - "python.testing.pytestEnabled": true -} \ No newline at end of file From a5744b1e3972b2833de4ee53de7cd8153d8864b4 Mon Sep 17 00:00:00 2001 From: Caini PAN Date: Thu, 6 Jun 2024 15:38:29 +0200 Subject: [PATCH 45/45] update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index fba19b76..c3f2bdd3 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,8 @@ dist/ examples/*.ipynb *.xml .coverage* +.vscode + .idea *.log