From 6d7d0573bc79433713aa3f1802117056a4feca3d Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Fri, 20 Sep 2024 19:55:02 +0200 Subject: [PATCH 1/4] Clean CPU examples --- examples/README.rst | 2 +- examples/conftest.py | 2 +- examples/example_2D_trajectories.py | 6 +- examples/example_3D_trajectories.py | 9 +- examples/example_display_config.py | 59 +++---- examples/example_gif_2D.py | 24 ++- examples/example_gif_3D.py | 21 ++- examples/example_learn_samples_multires.py | 190 +++++++++++++-------- examples/example_offresonance.py | 140 +++++++++------ examples/example_readme.py | 25 +-- examples/example_stacked.py | 63 ++++--- examples/example_trajectory_tools.py | 4 +- examples/utils.py | 5 +- 13 files changed, 334 insertions(+), 216 deletions(-) diff --git a/examples/README.rst b/examples/README.rst index aaa7ac77..52b101a0 100644 --- a/examples/README.rst +++ b/examples/README.rst @@ -3,4 +3,4 @@ Examples ======== -This is a collection of examples showing how to use mri-nufft to perform MR image reconstruction. +This is a collection of examples showing how to use MRI-nufft to perform MR image reconstruction. diff --git a/examples/conftest.py b/examples/conftest.py index d18499ed..85a3e688 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -12,8 +12,8 @@ """ -import sys import runpy +import sys from pathlib import Path import matplotlib as mpl diff --git a/examples/example_2D_trajectories.py b/examples/example_2D_trajectories.py index de4124ef..ed50a9f7 100644 --- a/examples/example_2D_trajectories.py +++ b/examples/example_2D_trajectories.py @@ -13,20 +13,16 @@ # are redundant across the different patterns, some of the documentation # will refer to previous patterns for explanation. # -# Note that most sources have not been added yet, but will be in the near -# future. -# # External import matplotlib.pyplot as plt import numpy as np +from utils import show_argument, show_trajectory # Internal import mrinufft as mn import mrinufft.trajectories.maths as mntm from mrinufft import display_2D_trajectory -from utils import show_argument, show_trajectory - # %% # Script options diff --git a/examples/example_3D_trajectories.py b/examples/example_3D_trajectories.py index 7a5c134c..21bab047 100644 --- a/examples/example_3D_trajectories.py +++ b/examples/example_3D_trajectories.py @@ -13,10 +13,10 @@ # are redundant across the different patterns, some of the documentation # will refer to previous patterns for explanation. # -# Note that most sources have not been added yet, but will be in the near -# future. Also the examples hereafter only cover natively 3D trajectories +# Note that the examples hereafter only cover natively 3D trajectories # or famous 3D trajectories obtained from 2D. Examples on how to use -# 2D-to-3D expansion methods will be presented over another page. +# tools to make 3D trajectories out of 2D ones are presented in +# :ref:`sphx_glr_generated_autoexamples_example_trajectory_tools.py` # # In this page in particular, we invite the user to manually run the script # to be able to manipulate the plot orientations with the matplotlib interface @@ -26,12 +26,11 @@ # External import matplotlib.pyplot as plt import numpy as np +from utils import show_argument, show_trajectory # Internal import mrinufft as mn from mrinufft import display_2D_trajectory, display_3D_trajectory -from utils import show_argument, show_trajectory - # %% # Script options diff --git a/examples/example_display_config.py b/examples/example_display_config.py index 1bf03679..8b91141a 100644 --- a/examples/example_display_config.py +++ b/examples/example_display_config.py @@ -3,20 +3,24 @@ Trajectory display configuration ================================ -The look of the display trajectories can be tweaked by using :py:class:`displayConfig` +An example to show how to customize trajectory displays. -You can tune these parameters to your own taste and needs. +The parameters presented here can be tuned to your own taste and needs +by using :py:class:`displayConfig`. """ import matplotlib as mpl import matplotlib.pyplot as plt - -# %% import numpy as np from mrinufft import display_2D_trajectory, display_3D_trajectory, displayConfig from mrinufft.trajectories import conify, initialize_2D_spiral +# %% +# Script options +# ============== +# These options are used in the examples below to define trajectories and display options. + # Trajectory parameters Nc = 120 # Number of shots Ns = 500 # Number of samples per shot @@ -26,7 +30,6 @@ subfigure_size = 6 # Figure size for subplots one_shot = -5 # Highlight one shot in particular - # %% @@ -47,28 +50,32 @@ def show_traj(traj, name, values, **kwargs): # %% # # Trajectory displays -# ==================== -# To show case the display parameters of trajectories, we will use the following trajectory -# The effect of trajectory parameter are explained in the :ref:`sphx_glr_generated_autoexamples_example_3D_trajectories.py` Example. +# =================== +# +# The following trajectory will be used to showcase the display parameters. +# The trajectory parameters are explained in the +# :ref:`sphx_glr_generated_autoexamples_example_3D_trajectories.py` example. traj = conify(initialize_2D_spiral(Nc // 6, Ns), nb_cones=6)[::-1] # %% # ``linewidth`` # ------------- -# The linewidth of the shot can be updated to have more or less empty space in the plot. +# The ``linewidth`` corresponds to the curve thickness, and can be changed +# to improve the shots visibility. show_traj(traj, "linewidth", [0.5, 2, 4]) # %% # ``palette`` # ----------- -# The ``palette`` parameter allows to change the color of the shots. +# The ``palette`` parameter allows you to change the color of the shots. show_traj(traj, "palette", ["tab10", "magma", "jet"]) # %% # ``one_shot_color`` # ------------------ -# The ``one_shot_color`` parameter allows to highlight one shot in particular. +# The ``one_shot_color`` parameter is used to highlight one shot in particular +# with a specified color. with displayConfig(palette="viridis"): show_traj( traj, "one_shot_color", ["tab:blue", "tab:orange", "tab:green"], one_shot=-5 @@ -77,10 +84,11 @@ def show_traj(traj, name, values, **kwargs): # %% # ``nb_colors`` # ------------- -# The ``nb_colors`` parameter allows to change the number of colors used to display the shots. - +# The ``nb_colors`` parameter allows you to change the number of colors used from the +# specified color palette to display the shots. show_traj(traj, "nb_colors", [1, 4, 10]) + # %% # Labels, titles and legends # ========================== @@ -88,38 +96,25 @@ def show_traj(traj, name, values, **kwargs): # %% # ``fontsize`` # ------------ -# The ``fontsize`` parameter allows to change the fontsize of the labels /title - +# The ``fontsize`` parameter changes the fontsize of the labels/titles. show_traj(traj, "fontsize", [12, 18, 24]) # %% # ``pointsize`` # ------------- -# To show the gradient constraint violation we can use the ``pointsize`` parameter +# The ``pointsize`` parameter is used when showing the gradient constraint violations +# to change the violation point sizes. show_traj(traj, "pointsize", [0.5, 2, 4], show_constraints=True) # %% # ``gradient_point_color`` and ``slewrate_point_color`` # ----------------------------------------------------- -# The ``gradient_point_color`` and ``slewrate_point_color`` parameters allows to change the color of the points -# that are violating the gradient or slewrate constraints. - +# The ``gradient_point_color`` and ``slewrate_point_color`` parameters allows you +# to change the color of the points where gradient or slew rate constraint violations +# are observed. show_traj( traj, "slewrate_point_color", ["tab:blue", "tab:orange", "tab:red"], show_constraints=True, ) - - -# %% -# Gradients profiles -# ================== - -# %% - -# %% - -# %% - -# %% diff --git a/examples/example_gif_2D.py b/examples/example_gif_2D.py index 47238b61..96673cd3 100644 --- a/examples/example_gif_2D.py +++ b/examples/example_gif_2D.py @@ -1,22 +1,26 @@ """ -======================= -2D Trajectories display -======================= +======================== +Animated 2D trajectories +======================== -A collection of 2D trajectories are generated and saved as a gif. +An animation to show 2D trajectory customization. """ +import time + import joblib import matplotlib.pyplot as plt import numpy as np from PIL import Image, ImageSequence -import time + import mrinufft.trajectories.display as mtd import mrinufft.trajectories.trajectory2D as mtt from mrinufft.trajectories.display import displayConfig -# Options +# %% +# Script options +# ============== Nc = 16 Ns = 200 @@ -29,7 +33,9 @@ duration = 150 # seconds -# Generation +# %% +# Trajectory generation +# ===================== # Initialize trajectory function functions = [ @@ -125,6 +131,10 @@ ] +# %% +# Animation rendering +# =================== + frame_setup = [ (f, i, name, arg) for (name, f), args in list(zip(functions, arguments)) diff --git a/examples/example_gif_3D.py b/examples/example_gif_3D.py index f22d6bd6..80704e20 100644 --- a/examples/example_gif_3D.py +++ b/examples/example_gif_3D.py @@ -1,13 +1,14 @@ """ -======================= -3D Trajectories display -======================= +======================== +Animated 3D trajectories +======================== -A collection of 3D trajectories are generated and saved as a gif. +An animation to show 3D trajectory customization. """ import time + import joblib import matplotlib.pyplot as plt import numpy as np @@ -17,7 +18,9 @@ import mrinufft.trajectories.trajectory3D as mtt from mrinufft.trajectories.display import displayConfig -# Options +# %% +# Script options +# ============== Nc = 8 * 8 Ns = 200 @@ -30,7 +33,9 @@ duration = 150 # seconds -# Generation +# %% +# Trajectory generation +# ===================== # Initialize trajectory function functions = [ @@ -134,6 +139,10 @@ ] +# %% +# Animation rendering +# =================== + frame_setup = [ (f, i, name, arg) for (name, f), args in list(zip(functions, arguments)) diff --git a/examples/example_learn_samples_multires.py b/examples/example_learn_samples_multires.py index 3c124a1d..b0895388 100644 --- a/examples/example_learn_samples_multires.py +++ b/examples/example_learn_samples_multires.py @@ -1,27 +1,37 @@ # %% """ -=============================================== -Learn Sampling pattern with multi-resolution -=============================================== +========================================= +Learning sampling pattern with decimation +========================================= -A small pytorch example to showcase learning k-space sampling patterns. -This example showcases the auto-diff capabilities of the NUFFT operator -wrt to k-space trajectory in mri-nufft. +An example using PyTorch to showcase learning k-space sampling patterns with decimation. -In this example we learn the k-space samples :math:`\mathbf{K}` for the following cost function: +This example showcases the auto-differentiation capabilities of the NUFFT operator +with respect to the k-space trajectory in MRI-nufft. + +Hereafter we learn the k-space sample locations :math:`\mathbf{K}` using the following cost function: .. math:: \mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2 -where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed. +where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator, +:math:`D_\mathbf{K}` is the density compensator for trajectory :math:`\mathbf{K}`, +and :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed. -Additionally, in-order to converge faster, we also learn the trajectory in a multi-resolution fashion. This is done by first optimizing a 8 times decimated trajectory locations, called control points. After a fixed number of iterations (5 in this example), these control points are upscaled by a factor of 2. However, note that the NUFFT operator always holds linearly interpolated version of the control points as k-space sampling trajectory. +Additionally, in order to converge faster, we also learn the trajectory in a multi-resolution fashion. +This is done by first optimizing x8 times decimated trajectory locations, called control points. +After a fixed number of iterations (5 in this example), these control points are upscaled by a factor of 2. +Note that the NUFFT operator always holds linearly interpolated version of the control points as k-space sampling trajectory. .. note:: This example can run on a binder instance as it is purely CPU based backend (finufft), and is restricted to a 2D single coil toy case. .. warning:: - This example only showcases the autodiff capabilities, the learned sampling pattern is not scanner compliant as the scanner gradients required to implement it violate the hardware constraints. In practice, a projection :math:`\Pi_\mathcal{Q}(\mathbf{K})` into the scanner constraints set :math:`\mathcal{Q}` is recommended (see [Proj]_). This is implemented in the proprietary SPARKLING package [Sparks]_. Users are encouraged to contact the authors if they want to use it. + This example only showcases the auto-differentiation capabilities, the learned sampling pattern + is not scanner compliant as the gradients required to implement it violate the hardware constraints. + In practice, a projection :math:`\Pi_\mathcal{Q}(\mathbf{K})` onto the scanner constraints set :math:`\mathcal{Q}` is recommended + (see [Cha+16]_). This is implemented in the proprietary SPARKLING package [Cha+22]_. + Users are encouraged to contact the authors if they want to use it. """ # %% # .. colab-link:: @@ -29,29 +39,28 @@ # # !pip install mri-nufft[finufft] -# %% -# Imports -# ------- - import time -import joblib import brainweb_dl as bwdl +import joblib import matplotlib.pyplot as plt import numpy as np import torch -from tqdm import tqdm from PIL import Image, ImageSequence +from tqdm import tqdm from mrinufft import get_operator from mrinufft.trajectories import initialize_2D_radial # %% -# Setup a simple class to learn trajectory -# ---------------------------------------- +# Utils +# ===== +# +# Model class +# ----------- # .. note:: # While we are only learning the NUFFT operator, we still need the gradient `wrt_data=True` to have all the gradients computed correctly. -# See [Projector]_ for more details. +# See [GRC23]_ for more details. class Model(torch.nn.Module): @@ -116,29 +125,44 @@ def forward(self, x): # %% -# Util function to plot the state of the model -# -------------------------------------------- +# State plotting +# -------------- -def plot_state( - axs, mri_2D, traj, recon, control_points=None, loss=None, save_name=None -): +def plot_state(axs, image, traj, recon, control_points=None, loss=None, save_name=None): axs = axs.flatten() - axs[0].imshow(np.abs(mri_2D[0]), cmap="gray") + # Upper left reference image + axs[0].imshow(np.abs(image[0]), cmap="gray") axs[0].axis("off") axs[0].set_title("MR Image") + + # Upper right trajectory axs[1].scatter(*traj.T, s=0.5) if control_points is not None: axs[1].scatter(*control_points.T, s=1, color="r") - axs[1].legend(["Trajectory", "Control Points"]) + axs[1].legend( + ["Trajectory", "Control points"], loc="right", bbox_to_anchor=(2, 0.6) + ) + axs[1].grid(True) axs[1].set_title("Trajectory") + axs[1].set_xlim(-0.5, 0.5) + axs[1].set_ylim(-0.5, 0.5) + axs[1].set_aspect("equal") + + # Down left reconstructed image axs[2].imshow(np.abs(recon[0][0].detach().cpu().numpy()), cmap="gray") axs[2].axis("off") axs[2].set_title("Reconstruction") + + # Down right loss evolution if loss is not None: axs[3].plot(loss) + axs[3].set_ylim(0, None) axs[3].grid("on") axs[3].set_title("Loss") + plt.subplots_adjust(hspace=0.3) + + # Save & close if save_name is not None: plt.savefig(save_name, bbox_inches="tight") plt.close() @@ -146,6 +170,11 @@ def plot_state( plt.show() +# %% +# Optimizer upscaling +# ------------------- + + def upsample_optimizer(optimizer, new_optimizer, factor=2): """Upsample the optimizer.""" for old_group, new_group in zip(optimizer.param_groups, new_optimizer.param_groups): @@ -172,47 +201,61 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): # %% -# Setup Inputs (models, trajectory and image) -# ------------------------------------------- -# First we create the model with a simple radial trajectory (32 shots of 256 points) +# Data preparation +# ================ +# +# A single image to train the model over. Note that in practice +# we would use a whole dataset instead (e.g. fastMRI). +# -init_traj = initialize_2D_radial(32, 256).astype(np.float32) -model = Model(init_traj, img_size=(256, 256)) -model.eval() +volume = np.flip(bwdl.get_mri(4, "T1"), axis=(0, 1, 2)) +image = torch.from_numpy(volume[-80, ...].astype(np.float32))[None] +image = image / torch.mean(image) # %% -# The image on which we are going to train. -# .. note :: -# In practice we would use instead a dataset (e.g. fastMRI) -# +# A basic radial trajectory with an acceleration factor of 8. -mri_2D = torch.from_numpy(np.flipud(bwdl.get_mri(4, "T1")[80, ...]).astype(np.float32))[ - None -] -mri_2D = mri_2D / torch.mean(mri_2D) +AF = 8 +initial_traj = initialize_2D_radial(image.shape[1] // AF, image.shape[2]).astype( + np.float32 +) +# %% +# Trajectory learning +# =================== +# # Initialisation # -------------- -# Before training, here is the simple reconstruction we have using a -# density compensated adjoint. -recon = model(mri_2D) -fig, axs = plt.subplots(1, 3, figsize=(15, 5)) -plot_state(axs, mri_2D, init_traj, recon, model.control.detach().cpu().numpy()) +model = Model(initial_traj, img_size=image.shape[1:]) +model = model.eval() # %% -# Start training loop -# ------------------- +# The image obtained before learning the sampling pattern +# is highly degraded because of the acceleration factor and simplicity +# of the trajectory. + +initial_recons = model(image) + +fig, axs = plt.subplots(1, 3, figsize=(9, 3)) +plot_state(axs, image, initial_traj, initial_recons) + + +# %% +# Training loop +# ------------- + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) +model.train() + losses = [] image_files = [] -model.train() while model.current_decim >= 1: with tqdm(range(30), unit="steps") as tqdms: for i in tqdms: - out = model(mri_2D) - loss = torch.nn.functional.mse_loss(out, mri_2D[None, None]) + out = model(image) + loss = torch.nn.functional.mse_loss(out, image[None, None]) numpy_loss = (loss.detach().cpu().numpy(),) tqdms.set_postfix({"loss": numpy_loss}) @@ -232,7 +275,7 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): fig, axs = plt.subplots(2, 2, figsize=(10, 10), num=1) plot_state( axs, - mri_2D, + image, model.get_trajectory().detach().cpu().numpy(), out, model.control.detach().cpu().numpy(), @@ -248,6 +291,7 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): optimizer, torch.optim.Adam(model.parameters(), lr=1e-3) ) +# %% # Make a GIF of all images. imgs = [Image.open(img) for img in image_files] @@ -259,6 +303,7 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): duration=2, loop=0, ) + # sphinx_gallery_start_ignore # cleanup import os @@ -284,8 +329,6 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): # sphinx_gallery_end_ignore -# sphinx_gallery_thumbnail_path = 'generated/autoexamples/images/mrinufft_learn_traj_multires.gif' - # %% # .. image-sg:: /generated/autoexamples/images/mrinufft_learn_traj_multires.gif # :alt: example learn_samples @@ -293,42 +336,43 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): # :class: sphx-glr-single-img # %% -# Trained trajectory -# ------------------ +# Results +# ------- + model.eval() -recon = model(mri_2D) -fig, axs = plt.subplots(2, 2, figsize=(10, 10)) -plot_state( - axs, - mri_2D, - model.get_trajectory().detach().cpu().numpy(), - recon=recon, - control_points=None, - loss=losses, -) +final_recons = model(image) +final_traj = model.get_trajectory().detach().cpu().numpy() + +# %% + +fig, axs = plt.subplots(1, 3, figsize=(9, 3)) +plot_state(axs, image, final_traj, final_recons) plt.show() # %% -# .. note:: -# The above learned trajectory is not that good because: -# - The trajectory is trained only for 5 iterations per decimation level, resulting in a suboptimal trajectory. -# - In order to make the example CPU compliant, we had to resort to preventing density compensation, hence the reconstructor is not good. # -# Users are requested to checkout :ref:`sphx_glr_generated_autoexamples_GPU_example_learn_samples.py` for example with density compensation. +# The learned trajectory above improves the reconstruction quality as compared to +# the initial trajectory shown above. Note of course that the reconstructed +# image is far from perfect because of the documentation rendering constraints. +# In order to improve the results one can start by training it for more than +# just 5 iterations per decimation level. Also density compensation should be used, +# even though it was avoided here for CPU compliance. Check out +# :ref:`sphx_glr_generated_autoexamples_GPU_example_learn_samples.py` to know more. + # %% # References # ========== # -# .. [Proj] N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, "A Projection Algorithm for +# .. [Cha+16] N. Chauffert, P. Weiss, J. Kahn and P. Ciuciu, "A Projection Algorithm for # Gradient Waveforms Design in Magnetic Resonance Imaging," in # IEEE Transactions on Medical Imaging, vol. 35, no. 9, pp. 2026-2039, Sept. 2016, # doi: 10.1109/TMI.2016.2544251. -# .. [Sparks] G. R. Chaithya, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, +# .. [Cha+22] G. R. Chaithya, P. Weiss, G. Daval-Frérot, A. Massire, A. Vignaud and P. Ciuciu, # "Optimizing Full 3D SPARKLING Trajectories for High-Resolution Magnetic # Resonance Imaging," in IEEE Transactions on Medical Imaging, vol. 41, no. 8, # pp. 2105-2117, Aug. 2022, doi: 10.1109/TMI.2022.3157269. -# .. [Projector] Chaithya GR, and Philippe Ciuciu. 2023. "Jointly Learning Non-Cartesian +# .. [GRC23] Chaithya GR, and Philippe Ciuciu. 2023. "Jointly Learning Non-Cartesian # k-Space Trajectories and Reconstruction Networks for 2D and 3D MR Imaging # through Projection" Bioengineering 10, no. 2: 158. # https://doi.org/10.3390/bioengineering10020158 diff --git a/examples/example_offresonance.py b/examples/example_offresonance.py index a16c547c..761f9f66 100644 --- a/examples/example_offresonance.py +++ b/examples/example_offresonance.py @@ -1,13 +1,13 @@ """ -====================== -Off-resonance Corrected NUFFT Operator -====================== +====================================== +Off-resonance corrected NUFFT operator +====================================== -Example of Off-resonance Corrected NUFFT trajectory operator. +An example to show how to setup an off-resonance corrected NUFFT operator. -This examples show how to use the Off-resonance Corrected NUFFT operator to acquire -and reconstruct data in presence of field inhomogeneities. -Here a spiral trajectory is used as a demonstration. +This example shows how to use the off-resonance corrected (ORC) NUFFT operator +to reconstruct data in presence of B0 field inhomogeneities. +Hereafter a 2D spiral trajectory is used for demonstration. """ @@ -18,45 +18,70 @@ plt.rcParams["image.cmap"] = "gray" + # %% -# Data Generation -# =============== -# For realistic 2D image we will use a slice from the brainweb dataset. -# installable using ``pip install brainweb-dl`` +# Data preparation +# ================ +# +# Image loading +# ------------- +# +# For realistic a 2D image we will use the BrainWeb dataset, +# installable using ``pip install brainweb-dl``. from brainweb_dl import get_mri mri_data = get_mri(0, "T1") -mri_data = mri_data[::-1, ...][90] -plt.imshow(mri_data), plt.axis("off"), plt.title("ground truth") +mri_data = np.flip(mri_data, axis=(0, 1, 2))[90] + +# %% + +plt.imshow(mri_data) +plt.axis("off") +plt.title("Groundtruth") +plt.show() + # %% -# Masking -# =============== -# Here, we generate a binary mask to exclude the background. -# We perform a simple binary threshold; in real-world application, -# it is advised to use other tools (e.g., FSL-BET). +# Mask generation +# --------------- +# +# A binary mask is generated to exclude the background. +# We use a simple binary threshold for this example, but for real-world application +# it is advised to use more advanced methods and tools (e.g., FSL-BET). brain_mask = mri_data > 0.1 * mri_data.max() -plt.imshow(brain_mask), plt.axis("off"), plt.title("brain mask") # %% -# Field Generation -# =============== -# Here, we generate a radial B0 field with the same shape of -# the input Shepp-Logan phantom + +plt.imshow(brain_mask) +plt.axis("off") +plt.title("brain mask") +plt.show() + + +# %% +# B0 field map generation +# ----------------------- +# +# A dummy B0 field map is generated for this example using the input shape. from mrinufft.extras import make_b0map -# generate field b0map, _ = make_b0map(mri_data.shape, b0range=(-200, 200), mask=brain_mask) -plt.imshow(brain_mask * b0map, cmap="bwr", vmin=-200, vmax=200), plt.axis( - "off" -), plt.colorbar(), plt.title("B0 map [Hz]") # %% -# Generate a Spiral trajectory -# ---------------------------- + +plt.imshow(brain_mask * b0map, cmap="bwr", vmin=-200, vmax=200) +plt.axis("off") +plt.colorbar() +plt.title("B0 map [Hz]") +plt.show() + + +# %% +# Trajectory generation +# --------------------- from mrinufft import initialize_2D_spiral from mrinufft.density import voronoi @@ -67,44 +92,61 @@ t_read = np.repeat(t_read[None, ...], samples.shape[0], axis=0) density = voronoi(samples) +# %% + display_2D_trajectory(samples) +plt.show() # %% -# Setup the Operator -# ================== +# Operator setup +# ============== from mrinufft import get_operator from mrinufft.operators.off_resonance import MRIFourierCorrected # Generate standard NUFFT operator nufft = get_operator("finufft")( - samples=samples, + samples=2 * np.pi * samples, # normalize for finufft shape=mri_data.shape, density=density, ) -# Generate Fourier Corrected operator -mfi_nufft = MRIFourierCorrected( +# Generate NUFFT off-resonance corrected operator +orc_nufft = MRIFourierCorrected( nufft, b0_map=b0map, readout_time=t_read, mask=brain_mask ) -# Generate K-Space -kspace = mfi_nufft.op(mri_data) - -# Reconstruct without field correction -mri_data_adj = nufft.adj_op(kspace) -mri_data_adj = np.squeeze(abs(mri_data_adj)) +# Generate k-space +kspace_on = nufft.op(mri_data) +kspace_off = orc_nufft.op(mri_data) -# Reconstruct with field correction -mri_data_adj_mfi = mfi_nufft.adj_op(kspace) -mri_data_adj_mfi = np.squeeze(abs(mri_data_adj_mfi)) +# Reconstruct without B0 field inhomogeneity +mri_data_adj_ref = nufft.adj_op(kspace_on) +mri_data_adj_ref = np.squeeze(abs(mri_data_adj_ref)) -fig2, ax2 = plt.subplots(1, 2) -ax2[0].imshow(mri_data_adj), ax2[0].axis("off"), ax2[0].set_title("w/o correction") -ax2[1].imshow(mri_data_adj_mfi), ax2[1].axis("off"), ax2[1].set_title("with correction") +# Reconstruct without B0 field correction +mri_data_adj = nufft.adj_op(kspace_off) +mri_data_adj = np.squeeze(abs(mri_data_adj)) -plt.show() +# Reconstruct with B0 field correction +mri_data_adj_orc = orc_nufft.adj_op(kspace_off) +mri_data_adj_orc = np.squeeze(abs(mri_data_adj_orc)) # %% -# The blurring is significantly reduced using the Off-resonance Corrected -# operator (right) +# The blurring observed in the presence of B0 field inhomogeneities (middle) +# is significantly reduced using the off-resonance corrected NUFFT operator (right). + +fig2, ax2 = plt.subplots(1, 3, figsize=(9, 3)) +# No off-resonance +ax2[0].imshow(mri_data_adj_ref) +ax2[0].axis("off") +ax2[0].set_title("No off-resonance") +# No off-resonance correction +ax2[1].imshow(mri_data_adj) +ax2[1].axis("off") +ax2[1].set_title("Off-resonance") +# Off-resonance corrected +ax2[2].imshow(mri_data_adj_orc) +ax2[2].axis("off") +ax2[2].set_title("Corrected off-resonance") +plt.show() diff --git a/examples/example_readme.py b/examples/example_readme.py index d1341f26..6be6d60a 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -1,8 +1,8 @@ """ -Minimal Example script +Minimal example script ====================== -This script shows how to use the package to perform a simple NUFFT. +An example to show how to perform a simple NUFFT. """ import matplotlib.pyplot as plt @@ -13,7 +13,7 @@ from mrinufft.density import voronoi from mrinufft.trajectories import display -# Create a 2D Radial trajectory for demo +# Create a 2D radial trajectory for demo samples_loc = mrinufft.initialize_2D_radial(Nc=100, Ns=500) # Get a 2D image for the demo (512x512) image = np.complex64(face(gray=True)[256:768, 256:768]) @@ -25,27 +25,32 @@ # For better image quality we use a density compensation density = voronoi(samples_loc) -# And create the associated operator. +# And create the associated operator nufft = NufftOperator( samples_loc, shape=image.shape, density=density, n_coils=1, squeeze_dims=True ) -kspace_data = nufft.op(image) # Image -> Kspace -image2 = nufft.adj_op(kspace_data) # Kspace -> Image +kspace_data = nufft.op(image) # Image -> K-space +image2 = nufft.adj_op(kspace_data) # K-space -> Image + +# %% # Show the results fig, ax = plt.subplots(2, 2) ax = ax.flatten() - +# Upper left reference image ax[0].imshow(abs(image), cmap="gray") ax[0].axis("off") ax[0].set_title("original image") +# Upper right trajectory display.display_2D_trajectory(samples_loc, subfigure=ax[1]) ax[1].set_aspect("equal") ax[1].set_title("Sampled points in k-space") +# Bottom left reconstructed image ax[2].imshow(abs(image2), cmap="gray") ax[2].axis("off") ax[2].set_title("Auto adjoint image") +# Bottom right error ax[3].imshow( abs(image2) / np.max(abs(image2)) - abs(image) / np.max(abs(image)), cmap="gray" ) @@ -57,7 +62,7 @@ # %% # .. note:: -# This image is not the same as the original one because the NUFFT operator -# is not a perfect adjoint, and we undersampled by a factor of 5. -# The artefact of reconstruction can be remove by using an iterative reconstruction method. +# This resulting image is not the same as the original one because the NUFFT operator +# is not a perfect inverse operation but an adjoint, and we undersampled by a factor of 5. +# The reconstruction artifacts can be removed by using an iterative reconstruction method. # Check PySAP-mri documentation for examples. diff --git a/examples/example_stacked.py b/examples/example_stacked.py index b70a31f8..10af7e04 100644 --- a/examples/example_stacked.py +++ b/examples/example_stacked.py @@ -1,13 +1,13 @@ """ ====================== -Stacked NUFFT Operator +Stacked NUFFT operator ====================== -Example of Stacked NUFFT trajectory operator. +An example to show how to setup a stacked NUFFT operator. -This examples show how to use the Stacked NUFFT operator to acquire and reconstruct data -in kspace where the sampling of pattern is a stack of non cartesian trajectory. -Here a stack of spiral is used as a demonstration. +This example shows how to use the stacked NUFFT operator to reconstruct data +when the sampling pattern in k-space is a stack of 2D non-Cartesian trajectories. +Hereafter a stack of 2D spirals is used for demonstration. """ @@ -18,43 +18,61 @@ plt.rcParams["image.cmap"] = "gray" + # %% -# Data Generation -# =============== -# For realistic 3D images we will use the brainweb dataset. -# installable using ``pip install brainweb-dl`` +# Data preparation +# ================ +# +# Image loading +# ------------- +# +# For realistic 3D images we will use the BrainWeb dataset, +# installable using ``pip install brainweb-dl``. from brainweb_dl import get_mri mri_data = get_mri(0, "T1") -mri_data = mri_data[::-1, ...] -fig, ax = plt.subplots(1, 3) +mri_data = np.flip(mri_data, axis=(0, 1, 2)) + +# %% + +fig, ax = plt.subplots(1, 3, figsize=(10, 3)) ax[0].imshow(mri_data[90, :, :]) ax[1].imshow(mri_data[:, 108, :]) ax[2].imshow(mri_data[:, :, 90]) +plt.show() + # %% -# Generate a Spiral trajectory -# ---------------------------- +# Trajectory generation +# --------------------- +# +# Only the 2D pattern needs to be initialized, along with +# its density to improve the adjoint NUFFT operation and +# the location of the different slices. +# from mrinufft import initialize_2D_spiral from mrinufft.density import voronoi samples = initialize_2D_spiral(Nc=16, Ns=500, nb_revolutions=10) density = voronoi(samples) +kz_slices = np.arange(mri_data.shape[-1]) # Specify locations for the stacks. + +# %% display_2D_trajectory(samples) -# specify locations for the stack of trajectories. -kz_slices = np.arange(mri_data.shape[-1]) +plt.show() + # %% -# Setup the Operator -# ================== +# Operator setup +# ============== from mrinufft.operators.stacked import MRIStackedNUFFT stacked_nufft = MRIStackedNUFFT( - samples=samples, + samples=2 * np.pi * samples, # normalize for finufft shape=mri_data.shape, z_index=kz_slices, backend="finufft", @@ -64,15 +82,16 @@ ) kspace_stack = stacked_nufft.op(mri_data) -print(kspace_stack.shape) +print(f"K-space shape: {kspace_stack.shape}") mri_data_adj = stacked_nufft.adj_op(kspace_stack) mri_data_adj = np.squeeze(abs(mri_data_adj)) -print(mri_data_adj.shape) +print(f"Volume shape: {mri_data_adj.shape}") -fig2, ax2 = plt.subplots(1, 3) +# %% + +fig2, ax2 = plt.subplots(1, 3, figsize=(10, 3)) ax2[0].imshow(mri_data_adj[90, :, :]) ax2[1].imshow(mri_data_adj[:, 108, :]) ax2[2].imshow(mri_data_adj[:, :, 90]) - plt.show() diff --git a/examples/example_trajectory_tools.py b/examples/example_trajectory_tools.py index be33ed44..e0a2efff 100644 --- a/examples/example_trajectory_tools.py +++ b/examples/example_trajectory_tools.py @@ -24,14 +24,12 @@ # External import matplotlib.pyplot as plt import numpy as np +from utils import show_argument, show_trajectory # Internal import mrinufft as mn import mrinufft.trajectories.tools as tools - from mrinufft.trajectories.utils import KMAX -from utils import show_argument, show_trajectory - # %% # Script options diff --git a/examples/utils.py b/examples/utils.py index 4db98042..7c70595b 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -3,12 +3,13 @@ the examples. """ +import matplotlib.pyplot as plt + # External imports import numpy as np -import matplotlib.pyplot as plt # Internal imports -from mrinufft import displayConfig, display_2D_trajectory, display_3D_trajectory +from mrinufft import display_2D_trajectory, display_3D_trajectory, displayConfig def show_argument(function, arguments, one_shot, subfig_size, dim="3D", axes=(0, 1)): From d68692a85008ddb857542a707a82ba0065cb4098 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Sat, 28 Sep 2024 09:22:36 +0200 Subject: [PATCH 2/4] Make examples sorted by title --- docs/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/conf.py b/docs/conf.py index dc91bc45..94050f9d 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -78,6 +78,7 @@ "reference_url": {"mrinufft": None}, "examples_dirs": ["../examples/"], "gallery_dirs": ["generated/autoexamples"], + "within_subsection_order": "ExampleTitleSortKey", "filename_pattern": "/example_", "ignore_pattern": r"(__init__|conftest|utils).py", "nested_sections": True, From d76c4431d6502d1d7ac9c7ae16d504b1d1761296 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Sat, 28 Sep 2024 10:17:34 +0200 Subject: [PATCH 3/4] Clean new CG example and move it --- examples/{ => GPU}/example_cg.py | 38 ++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 16 deletions(-) rename examples/{ => GPU}/example_cg.py (75%) diff --git a/examples/example_cg.py b/examples/GPU/example_cg.py similarity index 75% rename from examples/example_cg.py rename to examples/GPU/example_cg.py index 443606df..717c293b 100644 --- a/examples/example_cg.py +++ b/examples/GPU/example_cg.py @@ -1,24 +1,27 @@ -# %% """ -Example of using the Conjugate Gradient method. +====================================== +Reconstruction with conjugate gradient +====================================== + +An example to show how to reconstruct volumes using conjugate gradient method. -This script demonstrates the use of the Conjugate Gradient (CG) method -for solving systems of linear equations of the form Ax = b, where A is a symmetric -positive-definite matrix. The CG method is an iterative algorithm that is particularly +This script demonstrates the use of the Conjugate Gradient (CG) method +for solving systems of linear equations of the form Ax = b, where A is a symmetric +positive-definite matrix. The CG method is an iterative algorithm that is particularly useful for large, sparse systems where direct methods are computationally expensive. -The Conjugate Gradient method is widely used in various scientific and engineering -applications, including solving partial differential equations, optimization problems, +The Conjugate Gradient method is widely used in various scientific and engineering +applications, including solving partial differential equations, optimization problems, and machine learning tasks. References ---------- -- Inpirations: +- Inpirations: - https://sigpy.readthedocs.io/en/latest/_modules/sigpy/alg.html#ConjugateGradient - https://aquaulb.github.io/book_solving_pde_mooc/solving_pde_mooc/notebooks/05_IterativeMethods/05_02_Conjugate_Gradient.html -- Wikipedia: - - https://en.wikipedia.org/wiki/Conjugate_gradient_method - - https://en.wikipedia.org/wiki/Momentum +- Wikipedia: + - https://en.wikipedia.org/wiki/Conjugate_gradient_method + - https://en.wikipedia.org/wiki/Momentum """ # %% @@ -42,7 +45,10 @@ density = voronoi(samples_loc) # get the density nufft = NufftOperator( - samples_loc, shape=image.shape, density=density, n_coils=1 + samples_loc, + shape=image.shape, + density=density, + n_coils=1, ) # create the NUFFT operator # %% @@ -52,17 +58,17 @@ # %% # Display the results -plt.figure(figsize=(10, 5)) +plt.figure(figsize=(9, 3)) plt.subplot(1, 3, 1) -plt.title("Original Image") +plt.title("Original image") plt.imshow(abs(image), cmap="gray") plt.subplot(1, 3, 2) -plt.title("Reconstructed Image with CG") +plt.title("Conjugate gradient") plt.imshow(abs(reconstructed_image), cmap="gray") plt.subplot(1, 3, 3) -plt.title("Reconstructed Image with adjoint") +plt.title("Adjoint NUFFT") plt.imshow(abs(nufft.adj_op(kspace_data)), cmap="gray") plt.show() From 8eb51781818fa706e601ce6f1188551c76f664f8 Mon Sep 17 00:00:00 2001 From: Guillaume DAVAL-FREROT Date: Sat, 28 Sep 2024 10:54:08 +0200 Subject: [PATCH 4/4] Use tempfile for gifs in examples --- examples/example_gif_2D.py | 6 +++--- examples/example_gif_3D.py | 6 +++--- examples/example_learn_samples_multires.py | 4 ++-- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/example_gif_2D.py b/examples/example_gif_2D.py index 96673cd3..4ad96d0c 100644 --- a/examples/example_gif_2D.py +++ b/examples/example_gif_2D.py @@ -12,6 +12,7 @@ import joblib import matplotlib.pyplot as plt import numpy as np +import tempfile as tmp from PIL import Image, ImageSequence import mrinufft.trajectories.display as mtd @@ -142,7 +143,7 @@ ] -def draw_frame(func, index, name, arg, save_dir="/tmp/"): +def draw_frame(func, index, name, arg): """Draw a single frame of the gif and save it to a tmp file.""" trajectory = func(arg) # General configuration @@ -168,8 +169,7 @@ def draw_frame(func, index, name, arg, save_dir="/tmp/"): ) # Save figure - hashed = joblib.hash((index, name, arg, time.time())) - filename = save_dir + f"{hashed}.png" + filename = f"{tmp.NamedTemporaryFile().name}.png" plt.savefig(filename, bbox_inches="tight") plt.close() return filename diff --git a/examples/example_gif_3D.py b/examples/example_gif_3D.py index 80704e20..2b784777 100644 --- a/examples/example_gif_3D.py +++ b/examples/example_gif_3D.py @@ -12,6 +12,7 @@ import joblib import matplotlib.pyplot as plt import numpy as np +import tempfile as tmp from PIL import Image, ImageSequence import mrinufft.trajectories.display as mtd @@ -150,7 +151,7 @@ ] -def draw_frame(func, index, name, arg, save_dir="/tmp/"): +def draw_frame(func, index, name, arg): """Draw a single frame of the gif and save it to a tmp file.""" trajectory = func(arg) # General configuration @@ -173,8 +174,7 @@ def draw_frame(func, index, name, arg, save_dir="/tmp/"): ) # Save figure - hashed = joblib.hash((index, name, arg, time.time())) - filename = save_dir + f"{hashed}.png" + filename = f"{tmp.NamedTemporaryFile().name}.png" plt.savefig(filename, bbox_inches="tight") plt.close() return filename diff --git a/examples/example_learn_samples_multires.py b/examples/example_learn_samples_multires.py index b0895388..f84737dd 100644 --- a/examples/example_learn_samples_multires.py +++ b/examples/example_learn_samples_multires.py @@ -45,6 +45,7 @@ import joblib import matplotlib.pyplot as plt import numpy as np +import tempfile as tmp import torch from PIL import Image, ImageSequence from tqdm import tqdm @@ -269,8 +270,7 @@ def upsample_optimizer(optimizer, new_optimizer, factor=2): for param in model.parameters(): param.clamp_(-0.5, 0.5) # Generate images for gif - hashed = joblib.hash((i, "learn_traj", time.time())) - filename = "/tmp/" + f"{hashed}.png" + filename = f"{tmp.NamedTemporaryFile().name}.png" plt.clf() fig, axs = plt.subplots(2, 2, figsize=(10, 10), num=1) plot_state(