Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Conjugate Gradient with Momentum #185

Merged
merged 19 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions examples/example_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""Example of using the Conjugate Gradient method."""

import numpy as np
import mrinufft
from brainweb_dl import get_mri
from mrinufft.extras.gradient import cg
from mrinufft.density import voronoi
from matplotlib import pyplot as plt
from scipy.datasets import face

samples_loc = mrinufft.initialize_2D_radial(Nc=64, Ns=172)
image = get_mri(sub_id=4)
image = np.flipud(image[90])

NufftOperator = mrinufft.get_operator("gpunufft") # get the operator
density = voronoi(samples_loc) # get the density

nufft = NufftOperator(
samples_loc, shape=image.shape, density=density, n_coils=1
) # create the NUFFT operator

kspace_data = nufft.op(image) # get the k-space data
reconstructed_image = cg(nufft, kspace_data) # reconstruct the image


# Display the results
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(abs(image), cmap="gray")

Lenoush marked this conversation as resolved.
Show resolved Hide resolved
plt.subplot(1, 2, 2)
plt.title("Reconstructed Image")
plt.imshow(abs(reconstructed_image), cmap="gray")
plt.show()
62 changes: 62 additions & 0 deletions src/mrinufft/extras/gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Conjugate gradient optimization algorithm for image reconstruction."""

import numpy as np

from mrinufft.operators.base import with_numpy


@with_numpy
def cg(operator, kspace_data, x_init=None, num_iter=10, tol=1e-4):
"""
Perform conjugate gradient (CG) optimization for image reconstruction.

The image is updated using the gradient of a data consistency term,
and a velocity vector is used to accelerate convergence.

Parameters
----------
kspace_data : numpy.ndarray
The k-space data to be used for image reconstruction.

x_init : numpy.ndarray, optional
An initial guess for the image. If None, an image of zeros with the same
shape as the expected output is used. Default is None.

num_iter : int, optional
The maximum number of iterations to perform. Default is 10.

tol : float, optional
The tolerance for convergence. If the norm of the gradient falls below
this value or the dot product between the image and k-space data is
non-positive, the iterations stop. Default is 1e-4.

Returns
-------
image : numpy.ndarray
The reconstructed image after the optimization process.
"""
Lipschitz_cst = operator.get_lipschitz_cst()
image = (
np.zeros(operator.shape, dtype=type(kspace_data[0]))
if x_init is None
else x_init
)
velocity = np.zeros_like(image)

grad = operator.data_consistency(image, kspace_data)
velocity = tol * velocity + grad / Lipschitz_cst
image = image - velocity

for _ in range(num_iter):
grad_new = operator.data_consistency(image, kspace_data)
if np.linalg.norm(grad_new) <= tol:
break

beta = np.dot(grad_new.flatten(), grad_new.flatten()) / np.dot(
grad.flatten(), grad.flatten()
)
velocity = grad_new + beta * velocity

image = image - velocity / Lipschitz_cst

return image
72 changes: 72 additions & 0 deletions tests/test_cg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""Test for the cg function."""

import numpy as np
import pytest
from pytest_cases import parametrize_with_cases, parametrize, fixture
from mrinufft.extras.gradient import cg
from mrinufft import get_operator
from case_trajectories import CasesTrajectories

from helpers import (
image_from_op,
param_array_interface,
)
from helpers import assert_almost_allclose


@fixture(scope="module")
@parametrize(
"backend",
[
"bart",
"finufft",
"cufinufft",
"gpunufft",
"sigpy",
"torchkbnufft-cpu",
"torchkbnufft-gpu",
"tensorflow",
],
)
@parametrize_with_cases(
"kspace_locs, shape",
cases=[
CasesTrajectories.case_random2D,
CasesTrajectories.case_grid2D,
CasesTrajectories.case_grid3D,
],
)
def operator(
request,
backend="pynfft",
kspace_locs=None,
shape=None,
n_coils=1,
):
"""Generate an operator."""
if backend in ["pynfft", "sigpy"] and kspace_locs.shape[-1] == 3:
pytest.skip("3D for slow cpu is not tested")
return get_operator(backend)(kspace_locs, shape, n_coils=n_coils, smaps=None)


@fixture(scope="module")
def image_data(operator):
"""Generate a random image. Remains constant for the module."""
return image_from_op(operator)


@param_array_interface
def test_cg(operator, array_interface, image_data):
"""Compare the interface to the raw NUDFT implementation."""
kspace_nufft = operator.op(image_data).squeeze()

image_cg = cg(operator, kspace_nufft)
kspace_cg = operator.op(image_cg).squeeze()

assert_almost_allclose(
kspace_cg,
kspace_nufft,
atol=2e-1,
rtol=1e-1,
mismatch=20,
)
Loading