Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Conjugate Gradient with Momentum #185

Merged
merged 19 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions examples/example_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# %%
"""
Example of using the Conjugate Gradient method.

This script demonstrates the use of the Conjugate Gradient (CG) method
for solving systems of linear equations of the form Ax = b, where A is a symmetric
positive-definite matrix. The CG method is an iterative algorithm that is particularly
useful for large, sparse systems where direct methods are computationally expensive.

The Conjugate Gradient method is widely used in various scientific and engineering
applications, including solving partial differential equations, optimization problems,
and machine learning tasks.

References
----------
- Inpirations:
- https://sigpy.readthedocs.io/en/latest/_modules/sigpy/alg.html#ConjugateGradient
- https://aquaulb.github.io/book_solving_pde_mooc/solving_pde_mooc/notebooks/05_IterativeMethods/05_02_Conjugate_Gradient.html
- Wikipedia:
- https://en.wikipedia.org/wiki/Conjugate_gradient_method
- https://en.wikipedia.org/wiki/Momentum
"""

# %%
# Imports
import numpy as np
import mrinufft
from brainweb_dl import get_mri
from mrinufft.extras.gradient import cg
from mrinufft.density import voronoi
from matplotlib import pyplot as plt

# %%
# Setup Inputs
samples_loc = mrinufft.initialize_2D_spiral(Nc=64, Ns=256)
image = get_mri(sub_id=4)
image = np.flipud(image[90])

# %%
# Setup the NUFFT operator
NufftOperator = mrinufft.get_operator("gpunufft") # get the operator
density = voronoi(samples_loc) # get the density

nufft = NufftOperator(
samples_loc, shape=image.shape, density=density, n_coils=1
) # create the NUFFT operator

# %%
# Reconstruct the image using the CG method
kspace_data = nufft.op(image) # get the k-space data
reconstructed_image = cg(nufft, kspace_data) # reconstruct the image

# %%
# Display the results
plt.figure(figsize=(10, 5))
plt.subplot(1, 3, 1)
plt.title("Original Image")
plt.imshow(abs(image), cmap="gray")

Lenoush marked this conversation as resolved.
Show resolved Hide resolved
plt.subplot(1, 3, 2)
plt.title("Reconstructed Image with CG")
plt.imshow(abs(reconstructed_image), cmap="gray")

plt.subplot(1, 3, 3)
plt.title("Reconstructed Image with adjoint")
plt.imshow(abs(nufft.adj_op(kspace_data)), cmap="gray")

plt.show()
62 changes: 62 additions & 0 deletions src/mrinufft/extras/gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Conjugate gradient optimization algorithm for image reconstruction."""

import numpy as np

from mrinufft.operators.base import with_numpy


@with_numpy
def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4):
"""
Perform conjugate gradient (CG) optimization for image reconstruction.

The image is updated using the gradient of a data consistency term,
and a velocity vector is used to accelerate convergence.

Parameters
----------
kspace_data : numpy.ndarray
The k-space data to be used for image reconstruction.

x_init : numpy.ndarray, optional
An initial guess for the image. If None, an image of zeros with the same
shape as the expected output is used. Default is None.

num_iter : int, optional
The maximum number of iterations to perform. Default is 10.

tol : float, optional
The tolerance for convergence. If the norm of the gradient falls below
this value or the dot product between the image and k-space data is
non-positive, the iterations stop. Default is 1e-4.

Returns
-------
image : numpy.ndarray
The reconstructed image after the optimization process.
"""
Lipschitz_cst = operator.get_lipschitz_cst()
image = (
np.zeros(operator.shape, dtype=type(kspace_data[0]))
if x_init is None
else x_init
)
velocity = np.zeros_like(image)

grad = operator.data_consistency(image, kspace_data)
velocity = tol * velocity + grad / Lipschitz_cst
image = image - velocity

for _ in range(num_iter):
grad_new = operator.data_consistency(image, kspace_data)
if np.linalg.norm(grad_new) <= tol:
break

beta = np.dot(grad_new.flatten(), grad_new.flatten()) / np.dot(
grad.flatten(), grad.flatten()
)
velocity = grad_new + beta * velocity

image = image - velocity / Lipschitz_cst

return image
72 changes: 72 additions & 0 deletions tests/test_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Test for the cg function."""

import numpy as np
import pytest
from pytest_cases import parametrize_with_cases, parametrize, fixture
from mrinufft.extras.gradient import cg
from mrinufft import get_operator
from case_trajectories import CasesTrajectories

from helpers import (
image_from_op,
param_array_interface,
)
from helpers import assert_almost_allclose


@fixture(scope="module")
@parametrize(
"backend",
[
"bart",
"finufft",
"cufinufft",
"gpunufft",
"sigpy",
"torchkbnufft-cpu",
"torchkbnufft-gpu",
"tensorflow",
],
)
@parametrize_with_cases(
"kspace_locs, shape",
cases=[
CasesTrajectories.case_random2D,
CasesTrajectories.case_grid2D,
CasesTrajectories.case_grid3D,
],
)
def operator(
request,
backend="pynfft",
kspace_locs=None,
shape=None,
n_coils=1,
):
"""Generate an operator."""
if backend in ["pynfft", "sigpy"] and kspace_locs.shape[-1] == 3:
pytest.skip("3D for slow cpu is not tested")
return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None)


@fixture(scope="module")
def image_data(operator):
"""Generate a random image. Remains constant for the module."""
return image_from_op(operator)


@param_array_interface
def test_cg(operator, array_interface, image_data):
"""Compare the interface to the raw NUDFT implementation."""
kspace_nufft = operator.op(image_data).squeeze()

image_cg = cg(operator, kspace_nufft)
kspace_cg = operator.op(image_cg).squeeze()

assert_almost_allclose(
kspace_cg,
kspace_nufft,
atol=2e-1,
rtol=1e-1,
mismatch=20,
)
Loading