From 5b9102ccb31936efecc7aa78bc7eccee9fd259ea Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Thu, 29 Aug 2024 15:33:34 +0200 Subject: [PATCH 01/11] add cg function in base --- src/mrinufft/operators/base.py | 45 ++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 13689dae..65c09efa 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -460,6 +460,51 @@ def get_lipschitz_cst(self, max_iter=10, **kwargs): tmp_op = self return power_method(max_iter, tmp_op) + def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): + """ + Perform conjugate gradient (CG) optimization for image reconstruction. + + The image is updated using the gradient of a data consistency term, + and a velocity vector is used to accelerate convergence. + + Parameters + ---------- + kspace_data : numpy.ndarray + The k-space data to be used for image reconstruction. + + x_init : numpy.ndarray, optional + An initial guess for the image. If None, an image of zeros with the same + shape as the expected output is used. Default is None. + + num_iter : int, optional + The maximum number of iterations to perform. Default is 10. + + tol : float, optional + The tolerance for convergence. If the norm of the gradient falls below this + value or the dot product between the image and k-space data is non-positive, + the iterations stop. Default is 1e-4. + + Returns + ------- + image : numpy.ndarray + The reconstructed image after the optimization process. + """ + Lipschitz_cst = self.get_lipschitz_cst() + image = np.zeros(self.shape) if x_init is None else x_init + velocity = np.zeros_like(image) + + for _ in range(num_iter): + if np.real(np.dot(image, kspace_data)) <= 0: + break + + if np.sqrt(image) < tol: + break + + grad = self.data_consistency(image, kspace_data) + velocity = tol * velocity + grad * Lipschitz_cst + image = image - velocity + return image + @property def uses_sense(self): """Return True if the operator uses sensitivity maps.""" From 436ad798411fd4e7a87467fe077aa88fa2185f1f Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Thu, 29 Aug 2024 16:20:11 +0200 Subject: [PATCH 02/11] add test cg --- tests/test_cg.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 tests/test_cg.py diff --git a/tests/test_cg.py b/tests/test_cg.py new file mode 100644 index 00000000..e301e4f3 --- /dev/null +++ b/tests/test_cg.py @@ -0,0 +1,58 @@ +"""Test for the cg function.""" + +import numpy as np +import pytest +from pytest_cases import parametrize_with_cases, parametrize, fixture +from mrinufft import get_operator +from case_trajectories import CasesTrajectories + +from helpers import ( + kspace_from_op, + image_from_op, + to_interface, + from_interface, + param_array_interface, +) + + +@fixture(scope="module") +@parametrize( + "backend", + ["torchkbnufft-gpu"], +) +@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories) +def operator( + request, + backend="pynfft", + kspace_locs=None, + shape=None, + n_coils=1, +): + """Generate an operator.""" + if backend in ["pynfft", "sigpy"] and kspace_locs.shape[-1] == 3: + pytest.skip("3D for slow cpu is not tested") + return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None) + + +@fixture(scope="module") +def image_data(operator): + """Generate a random image. Remains constant for the module.""" + return image_from_op(operator) + + +@fixture(scope="module") +def kspace_data(operator): + """Generate a random kspace. Remains constant for the module.""" + return kspace_from_op(operator) + + +@param_array_interface +def test_cg(operator, array_interface, image_data): + """Compare the interface to the raw NUDFT implementation.""" + image_data_ = to_interface(image_data, array_interface) + kspace_nufft = operator.op(image_data_).squeeze() + + image_cg = operator.cg(kspace_nufft) + kspace_cg = operator.op(image_cg).squeeze() + assert np.allclose(kspace_nufft, kspace_cg, atol=1e-5, rtol=1e-5) + From ed398a837c356e09f5d1eea1b4527ca80e7d3b02 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Thu, 29 Aug 2024 16:32:14 +0200 Subject: [PATCH 03/11] fix cg function 1 --- src/mrinufft/operators/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 65c09efa..8b9b0e40 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -460,6 +460,7 @@ def get_lipschitz_cst(self, max_iter=10, **kwargs): tmp_op = self return power_method(max_iter, tmp_op) + @with_numpy def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): """ Perform conjugate gradient (CG) optimization for image reconstruction. @@ -494,10 +495,10 @@ def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): velocity = np.zeros_like(image) for _ in range(num_iter): - if np.real(np.dot(image, kspace_data)) <= 0: + if (np.real(np.dot(image, image.T)) <= 0).any(): break - if np.sqrt(image) < tol: + if (np.sqrt(image) < tol).any(): break grad = self.data_consistency(image, kspace_data) From 39d8e70b45156166f1540e6eca2e9795f2ea0dfc Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Fri, 30 Aug 2024 16:33:00 +0200 Subject: [PATCH 04/11] some changes --- src/mrinufft/operators/base.py | 14 ++++++-------- tests/test_cg.py | 23 ++++++++++------------- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 8b9b0e40..eae502f2 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -491,21 +491,19 @@ def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): The reconstructed image after the optimization process. """ Lipschitz_cst = self.get_lipschitz_cst() - image = np.zeros(self.shape) if x_init is None else x_init + image = np.zeros(self.shape, dtype=type(kspace_data[0])) if x_init is None else x_init velocity = np.zeros_like(image) for _ in range(num_iter): - if (np.real(np.dot(image, image.T)) <= 0).any(): - break - - if (np.sqrt(image) < tol).any(): - break - grad = self.data_consistency(image, kspace_data) velocity = tol * velocity + grad * Lipschitz_cst + + if np.linalg.norm(grad) < tol: + break image = image - velocity return image - + + @property def uses_sense(self): """Return True if the operator uses sensitivity maps.""" diff --git a/tests/test_cg.py b/tests/test_cg.py index e301e4f3..5873b6a8 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -7,18 +7,16 @@ from case_trajectories import CasesTrajectories from helpers import ( - kspace_from_op, image_from_op, - to_interface, - from_interface, param_array_interface, ) +from tests.helpers.asserts import assert_almost_allclose @fixture(scope="module") @parametrize( "backend", - ["torchkbnufft-gpu"], + ["finufft", "torchkbnufft-cpu"], ) @parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories) def operator( @@ -40,19 +38,18 @@ def image_data(operator): return image_from_op(operator) -@fixture(scope="module") -def kspace_data(operator): - """Generate a random kspace. Remains constant for the module.""" - return kspace_from_op(operator) - - @param_array_interface def test_cg(operator, array_interface, image_data): """Compare the interface to the raw NUDFT implementation.""" - image_data_ = to_interface(image_data, array_interface) - kspace_nufft = operator.op(image_data_).squeeze() + kspace_nufft = operator.op(image_data).squeeze() image_cg = operator.cg(kspace_nufft) kspace_cg = operator.op(image_cg).squeeze() - assert np.allclose(kspace_nufft, kspace_cg, atol=1e-5, rtol=1e-5) + assert_almost_allclose( + kspace_nufft, + kspace_cg, + atol=1e-1, + rtol=1e-1, + mismatch=20, + ) From 98605687f47f7228a6d2c077bb50f859cac71ad9 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Wed, 4 Sep 2024 10:15:57 +0200 Subject: [PATCH 05/11] fix minore : import and style --- src/mrinufft/operators/base.py | 7 +++++-- tests/test_cg.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index f92a3901..6bbf4721 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -528,7 +528,11 @@ def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): The reconstructed image after the optimization process. """ Lipschitz_cst = self.get_lipschitz_cst() - image = np.zeros(self.shape, dtype=type(kspace_data[0])) if x_init is None else x_init + image = ( + np.zeros(self.shape, dtype=type(kspace_data[0])) + if x_init is None + else x_init + ) velocity = np.zeros_like(image) for _ in range(num_iter): @@ -540,7 +544,6 @@ def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): image = image - velocity return image - @property def uses_sense(self): """Return True if the operator uses sensitivity maps.""" diff --git a/tests/test_cg.py b/tests/test_cg.py index 5873b6a8..b73b671d 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -10,7 +10,7 @@ image_from_op, param_array_interface, ) -from tests.helpers.asserts import assert_almost_allclose +from helpers import assert_almost_allclose @fixture(scope="module") From c34f0eaef717121071616c72cba2d2ebe30d5257 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Tue, 17 Sep 2024 09:15:10 +0200 Subject: [PATCH 06/11] mv cg function, fix test --- src/mrinufft/extras/gradient.py | 51 +++++++++++++++++++++++++++++++++ src/mrinufft/operators/base.py | 47 ------------------------------ tests/test_cg.py | 7 +++-- 3 files changed, 55 insertions(+), 50 deletions(-) create mode 100644 src/mrinufft/extras/gradient.py diff --git a/src/mrinufft/extras/gradient.py b/src/mrinufft/extras/gradient.py new file mode 100644 index 00000000..f435095b --- /dev/null +++ b/src/mrinufft/extras/gradient.py @@ -0,0 +1,51 @@ +import numpy as np + +from mrinufft.operators.base import with_numpy + +@with_numpy +def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4): + """ + Perform conjugate gradient (CG) optimization for image reconstruction. + + The image is updated using the gradient of a data consistency term, + and a velocity vector is used to accelerate convergence. + + Parameters + ---------- + kspace_data : numpy.ndarray + The k-space data to be used for image reconstruction. + + x_init : numpy.ndarray, optional + An initial guess for the image. If None, an image of zeros with the same + shape as the expected output is used. Default is None. + + num_iter : int, optional + The maximum number of iterations to perform. Default is 10. + + tol : float, optional + The tolerance for convergence. If the norm of the gradient falls below this + value or the dot product between the image and k-space data is non-positive, + the iterations stop. Default is 1e-4. + + Returns + ------- + image : numpy.ndarray + The reconstructed image after the optimization process. + """ + + Lipschitz_cst = operator.get_lipschitz_cst() + image = ( + np.zeros(operator.shape, dtype=type(kspace_data[0])) + if x_init is None + else x_init + ) + velocity = np.zeros_like(image) + + for _ in range(num_iter): + grad = operator.data_consistency(image, kspace_data) + velocity = tol * velocity + grad * Lipschitz_cst + if np.linalg.norm(grad) < tol: + break + image = image - velocity + + return image \ No newline at end of file diff --git a/src/mrinufft/operators/base.py b/src/mrinufft/operators/base.py index 6bbf4721..0c86e8bb 100644 --- a/src/mrinufft/operators/base.py +++ b/src/mrinufft/operators/base.py @@ -497,53 +497,6 @@ def get_lipschitz_cst(self, max_iter=10, **kwargs): tmp_op = self return power_method(max_iter, tmp_op) - @with_numpy - def cg(self, kspace_data, x_init=None, num_iter=10, tol=1e-4): - """ - Perform conjugate gradient (CG) optimization for image reconstruction. - - The image is updated using the gradient of a data consistency term, - and a velocity vector is used to accelerate convergence. - - Parameters - ---------- - kspace_data : numpy.ndarray - The k-space data to be used for image reconstruction. - - x_init : numpy.ndarray, optional - An initial guess for the image. If None, an image of zeros with the same - shape as the expected output is used. Default is None. - - num_iter : int, optional - The maximum number of iterations to perform. Default is 10. - - tol : float, optional - The tolerance for convergence. If the norm of the gradient falls below this - value or the dot product between the image and k-space data is non-positive, - the iterations stop. Default is 1e-4. - - Returns - ------- - image : numpy.ndarray - The reconstructed image after the optimization process. - """ - Lipschitz_cst = self.get_lipschitz_cst() - image = ( - np.zeros(self.shape, dtype=type(kspace_data[0])) - if x_init is None - else x_init - ) - velocity = np.zeros_like(image) - - for _ in range(num_iter): - grad = self.data_consistency(image, kspace_data) - velocity = tol * velocity + grad * Lipschitz_cst - - if np.linalg.norm(grad) < tol: - break - image = image - velocity - return image - @property def uses_sense(self): """Return True if the operator uses sensitivity maps.""" diff --git a/tests/test_cg.py b/tests/test_cg.py index b73b671d..d3bf47d6 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -3,6 +3,7 @@ import numpy as np import pytest from pytest_cases import parametrize_with_cases, parametrize, fixture +from mrinufft.extras.gradient import cg from mrinufft import get_operator from case_trajectories import CasesTrajectories @@ -43,13 +44,13 @@ def test_cg(operator, array_interface, image_data): """Compare the interface to the raw NUDFT implementation.""" kspace_nufft = operator.op(image_data).squeeze() - image_cg = operator.cg(kspace_nufft) + image_cg = cg(operator,kspace_nufft) kspace_cg = operator.op(image_cg).squeeze() - + assert_almost_allclose( kspace_nufft, kspace_cg, atol=1e-1, - rtol=1e-1, + rtol=5e-1, mismatch=20, ) From 476c154084d1e575f46b3101eb2fba34f92d17e3 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Tue, 17 Sep 2024 09:19:26 +0200 Subject: [PATCH 07/11] styles --- src/mrinufft/extras/gradient.py | 92 +++++++++++++++++---------------- tests/test_cg.py | 4 +- 2 files changed, 49 insertions(+), 47 deletions(-) diff --git a/src/mrinufft/extras/gradient.py b/src/mrinufft/extras/gradient.py index f435095b..a7be8bdd 100644 --- a/src/mrinufft/extras/gradient.py +++ b/src/mrinufft/extras/gradient.py @@ -1,51 +1,53 @@ +"""Conjugate gradient optimization algorithm for image reconstruction.""" + import numpy as np from mrinufft.operators.base import with_numpy + @with_numpy def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4): - """ - Perform conjugate gradient (CG) optimization for image reconstruction. - - The image is updated using the gradient of a data consistency term, - and a velocity vector is used to accelerate convergence. - - Parameters - ---------- - kspace_data : numpy.ndarray - The k-space data to be used for image reconstruction. - - x_init : numpy.ndarray, optional - An initial guess for the image. If None, an image of zeros with the same - shape as the expected output is used. Default is None. - - num_iter : int, optional - The maximum number of iterations to perform. Default is 10. - - tol : float, optional - The tolerance for convergence. If the norm of the gradient falls below this - value or the dot product between the image and k-space data is non-positive, - the iterations stop. Default is 1e-4. - - Returns - ------- - image : numpy.ndarray - The reconstructed image after the optimization process. - """ - - Lipschitz_cst = operator.get_lipschitz_cst() - image = ( - np.zeros(operator.shape, dtype=type(kspace_data[0])) - if x_init is None - else x_init - ) - velocity = np.zeros_like(image) - - for _ in range(num_iter): - grad = operator.data_consistency(image, kspace_data) - velocity = tol * velocity + grad * Lipschitz_cst - if np.linalg.norm(grad) < tol: - break - image = image - velocity - - return image \ No newline at end of file + """ + Perform conjugate gradient (CG) optimization for image reconstruction. + + The image is updated using the gradient of a data consistency term, + and a velocity vector is used to accelerate convergence. + + Parameters + ---------- + kspace_data : numpy.ndarray + The k-space data to be used for image reconstruction. + + x_init : numpy.ndarray, optional + An initial guess for the image. If None, an image of zeros with the same + shape as the expected output is used. Default is None. + + num_iter : int, optional + The maximum number of iterations to perform. Default is 10. + + tol : float, optional + The tolerance for convergence. If the norm of the gradient falls below + this value or the dot product between the image and k-space data is + non-positive, the iterations stop. Default is 1e-4. + + Returns + ------- + image : numpy.ndarray + The reconstructed image after the optimization process. + """ + Lipschitz_cst = operator.get_lipschitz_cst() + image = ( + np.zeros(operator.shape, dtype=type(kspace_data[0])) + if x_init is None + else x_init + ) + velocity = np.zeros_like(image) + + for _ in range(num_iter): + grad = operator.data_consistency(image, kspace_data) + velocity = tol * velocity + grad * Lipschitz_cst + if np.linalg.norm(grad) < tol: + break + image = image - velocity + + return image diff --git a/tests/test_cg.py b/tests/test_cg.py index d3bf47d6..8582490a 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -44,9 +44,9 @@ def test_cg(operator, array_interface, image_data): """Compare the interface to the raw NUDFT implementation.""" kspace_nufft = operator.op(image_data).squeeze() - image_cg = cg(operator,kspace_nufft) + image_cg = cg(operator, kspace_nufft) kspace_cg = operator.op(image_cg).squeeze() - + assert_almost_allclose( kspace_nufft, kspace_cg, From 233dee1b382c925e021718fa4857622898afab31 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Wed, 18 Sep 2024 14:36:13 +0200 Subject: [PATCH 08/11] fix cg and test, improve example --- examples/example_cg.py | 35 +++++++++++++++++++++++++++++++++ src/mrinufft/extras/gradient.py | 17 ++++++++++++---- tests/test_cg.py | 28 +++++++++++++++++++++----- 3 files changed, 71 insertions(+), 9 deletions(-) create mode 100644 examples/example_cg.py diff --git a/examples/example_cg.py b/examples/example_cg.py new file mode 100644 index 00000000..a1c2c745 --- /dev/null +++ b/examples/example_cg.py @@ -0,0 +1,35 @@ +"""Example of using the Conjugate Gradient method.""" + +import numpy as np +import mrinufft +from brainweb_dl import get_mri +from mrinufft.extras.gradient import cg +from mrinufft.density import voronoi +from matplotlib import pyplot as plt +from scipy.datasets import face + +samples_loc = mrinufft.initialize_2D_radial(Nc=64, Ns=172) +image = get_mri(sub_id=4) +image = np.flipud(image[90]) + +NufftOperator = mrinufft.get_operator("gpunufft") # get the operator +density = voronoi(samples_loc) # get the density + +nufft = NufftOperator( + samples_loc, shape=image.shape, density=density, n_coils=1 +) # create the NUFFT operator + +kspace_data = nufft.op(image) # get the k-space data +reconstructed_image = cg(nufft, kspace_data) # reconstruct the image + + +# Display the results +plt.figure(figsize=(10, 5)) +plt.subplot(1, 2, 1) +plt.title("Original Image") +plt.imshow(abs(image), cmap="gray") + +plt.subplot(1, 2, 2) +plt.title("Reconstructed Image") +plt.imshow(abs(reconstructed_image), cmap="gray") +plt.show() diff --git a/src/mrinufft/extras/gradient.py b/src/mrinufft/extras/gradient.py index a7be8bdd..7d18e76d 100644 --- a/src/mrinufft/extras/gradient.py +++ b/src/mrinufft/extras/gradient.py @@ -43,11 +43,20 @@ def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4): ) velocity = np.zeros_like(image) + grad = operator.data_consistency(image, kspace_data) + velocity = tol * velocity + grad / Lipschitz_cst + image = image - velocity + for _ in range(num_iter): - grad = operator.data_consistency(image, kspace_data) - velocity = tol * velocity + grad * Lipschitz_cst - if np.linalg.norm(grad) < tol: + grad_new = operator.data_consistency(image, kspace_data) + if np.linalg.norm(grad_new) <= tol: break - image = image - velocity + + beta = np.dot(grad_new.flatten(), grad_new.flatten()) / np.dot( + grad.flatten(), grad.flatten() + ) + velocity = grad_new + beta * velocity + + image = image - velocity / Lipschitz_cst return image diff --git a/tests/test_cg.py b/tests/test_cg.py index 8582490a..888edf2d 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -17,9 +17,27 @@ @fixture(scope="module") @parametrize( "backend", - ["finufft", "torchkbnufft-cpu"], + [ + "bart", + "pynfft", + "pynufft-cpu", + "finufft", + "cufinufft", + "gpunufft", + "sigpy", + "torchkbnufft-cpu", + "torchkbnufft-gpu", + "tensorflow", + ], +) +@parametrize_with_cases( + "kspace_locs, shape", + cases=[ + CasesTrajectories.case_random2D, + CasesTrajectories.case_grid2D, + CasesTrajectories.case_grid3D, + ], ) -@parametrize_with_cases("kspace_locs, shape", cases=CasesTrajectories) def operator( request, backend="pynfft", @@ -48,9 +66,9 @@ def test_cg(operator, array_interface, image_data): kspace_cg = operator.op(image_cg).squeeze() assert_almost_allclose( - kspace_nufft, kspace_cg, - atol=1e-1, - rtol=5e-1, + kspace_nufft, + atol=5e-1, + rtol=1e-1, mismatch=20, ) From 2c1c3645057796052d756da3d827b5cfbf328ae0 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Mon, 23 Sep 2024 13:41:22 +0200 Subject: [PATCH 09/11] minore --- tests/test_cg.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_cg.py b/tests/test_cg.py index 888edf2d..2f8e7ac8 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -19,8 +19,6 @@ "backend", [ "bart", - "pynfft", - "pynufft-cpu", "finufft", "cufinufft", "gpunufft", From 6c20a04f30daff6f044712aed64a6c7018568c30 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Tue, 24 Sep 2024 09:41:37 +0200 Subject: [PATCH 10/11] more precise --- tests/test_cg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_cg.py b/tests/test_cg.py index 2f8e7ac8..d0806c99 100644 --- a/tests/test_cg.py +++ b/tests/test_cg.py @@ -66,7 +66,7 @@ def test_cg(operator, array_interface, image_data): assert_almost_allclose( kspace_cg, kspace_nufft, - atol=5e-1, + atol=2e-1, rtol=1e-1, mismatch=20, ) From 16ffa2532e5116d99414956c22dca48ad6f07253 Mon Sep 17 00:00:00 2001 From: Lena OUDJMAN Date: Tue, 24 Sep 2024 14:11:37 +0200 Subject: [PATCH 11/11] add explanation, and adj image, fix into jupyter --- examples/example_cg.py | 47 +++++++++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 7 deletions(-) diff --git a/examples/example_cg.py b/examples/example_cg.py index a1c2c745..443606df 100644 --- a/examples/example_cg.py +++ b/examples/example_cg.py @@ -1,17 +1,43 @@ -"""Example of using the Conjugate Gradient method.""" +# %% +""" +Example of using the Conjugate Gradient method. +This script demonstrates the use of the Conjugate Gradient (CG) method +for solving systems of linear equations of the form Ax = b, where A is a symmetric +positive-definite matrix. The CG method is an iterative algorithm that is particularly +useful for large, sparse systems where direct methods are computationally expensive. + +The Conjugate Gradient method is widely used in various scientific and engineering +applications, including solving partial differential equations, optimization problems, +and machine learning tasks. + +References +---------- +- Inpirations: + - https://sigpy.readthedocs.io/en/latest/_modules/sigpy/alg.html#ConjugateGradient + - https://aquaulb.github.io/book_solving_pde_mooc/solving_pde_mooc/notebooks/05_IterativeMethods/05_02_Conjugate_Gradient.html +- Wikipedia: + - https://en.wikipedia.org/wiki/Conjugate_gradient_method + - https://en.wikipedia.org/wiki/Momentum +""" + +# %% +# Imports import numpy as np import mrinufft from brainweb_dl import get_mri from mrinufft.extras.gradient import cg from mrinufft.density import voronoi from matplotlib import pyplot as plt -from scipy.datasets import face -samples_loc = mrinufft.initialize_2D_radial(Nc=64, Ns=172) +# %% +# Setup Inputs +samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=256) image = get_mri(sub_id=4) image = np.flipud(image[90]) +# %% +# Setup the NUFFT operator NufftOperator = mrinufft.get_operator("gpunufft") # get the operator density = voronoi(samples_loc) # get the density @@ -19,17 +45,24 @@ samples_loc, shape=image.shape, density=density, n_coils=1 ) # create the NUFFT operator +# %% +# Reconstruct the image using the CG method kspace_data = nufft.op(image) # get the k-space data reconstructed_image = cg(nufft, kspace_data) # reconstruct the image - +# %% # Display the results plt.figure(figsize=(10, 5)) -plt.subplot(1, 2, 1) +plt.subplot(1, 3, 1) plt.title("Original Image") plt.imshow(abs(image), cmap="gray") -plt.subplot(1, 2, 2) -plt.title("Reconstructed Image") +plt.subplot(1, 3, 2) +plt.title("Reconstructed Image with CG") plt.imshow(abs(reconstructed_image), cmap="gray") + +plt.subplot(1, 3, 3) +plt.title("Reconstructed Image with adjoint") +plt.imshow(abs(nufft.adj_op(kspace_data)), cmap="gray") + plt.show()