Random sampling #220

Merged
merged 14 commits into from
Jan 30, 2025
28 changes: 14 additions & 14 deletions examples/GPU/example_fastMRI_UNet.py
@@ -4,30 +4,30 @@
Simple UNet model.
==================

This model is a simplified version of the U-Net architecture,
which is widely used for image segmentation tasks.
This is implemented in the proprietary FASTMRI package [fastmri]_.

The U-Net model consists of an encoder (downsampling path) and
a decoder (upsampling path) with skip connections between corresponding
layers in the encoder and decoder.
These skip connections help in retaining spatial information
that is lost during the downsampling process.

The primary purpose of this model is to perform image reconstruction tasks,
specifically for MRI images.
It takes an input MRI image and reconstructs it to improve the image quality
or to recover missing parts of the image.

This implementation of the UNet model was pulled from the FastMRI Facebook
repository, which is a collaborative research project aimed at advancing
the field of medical imaging using machine learning techniques.

.. math::

\mathbf{\hat{x}} = \mathrm{arg} \min_{\mathbf{x}} || \mathcal{U}_\mathbf{\theta}(\mathbf{y}) - \mathbf{x} ||_2^2

where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image,
:math:`\mathbf{y}` is the input MRI image (e.g., k-space data), and :math:`\mathcal{U}_\mathbf{\theta}` is the U-Net model parameterized by :math:`\theta`.
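
A minimal, self-contained sketch of the squared-error term above (an
illustration only, not the FastMRI U-Net): a toy convolutional network plays
the role of :math:`\mathcal{U}_\mathbf{\theta}`, with ``x`` the ground truth
and ``y`` a degraded input.

.. code-block:: python

    import torch

    net = torch.nn.Sequential(            # toy stand-in for the U-Net
        torch.nn.Conv2d(1, 8, 3, padding=1),
        torch.nn.ReLU(),
        torch.nn.Conv2d(8, 1, 3, padding=1),
    )
    x = torch.randn(1, 1, 64, 64)         # ground-truth image
    y = x + 0.1 * torch.randn_like(x)     # degraded input image
    loss = torch.mean((net(y) - x) ** 2)  # || U_theta(y) - x ||_2^2 (MSE form)
    loss.backward()                       # gradients w.r.t. theta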

.. warning::
4 changes: 2 additions & 2 deletions examples/GPU/example_learn_samples.py
@@ -5,15 +5,15 @@
======================

A small PyTorch example of learning k-space sampling patterns.
It showcases the auto-diff capabilities of the NUFFT operator
with respect to the k-space trajectory in mri-nufft.

In this example, we solve the following optimization problem:

.. math::

\mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2

where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed.
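
As an illustration only (not this example's actual code), the loss above can
be sketched with an explicit non-uniform DFT standing in for the NUFFT, so
that it stays differentiable with respect to the trajectory :math:`\mathbf{K}`;
the density compensator :math:`D_\mathbf{K}` is omitted for brevity.

.. code-block:: python

    import torch

    def ndft_matrix(K, grid):
        """Explicit non-uniform DFT matrix for locations K (M, 2) and pixel grid (N, 2)."""
        return torch.exp(-2j * torch.pi * (K @ grid.T))

    n = 16
    grid = torch.stack(
        torch.meshgrid(torch.arange(n), torch.arange(n), indexing="ij"), dim=-1
    ).reshape(-1, 2).float()
    x = torch.randn(n * n, dtype=torch.complex64)          # toy MR image (flattened)
    K = (torch.rand(4 * n, 2) - 0.5).requires_grad_(True)  # learnable trajectory

    F = ndft_matrix(K, grid)                               # forward model F_K
    x_hat = F.conj().T @ (F @ x) / (n * n)                 # F_K^* F_K x
    loss = torch.sum(torch.abs(x_hat - x) ** 2)            # || F_K^* F_K x - x ||_2^2
    loss.backward()                                        # gradient w.r.t. K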

.. warning::
10 changes: 5 additions & 5 deletions examples/GPU/example_learn_samples_multicoil.py
@@ -5,24 +5,24 @@
=========================================

A small PyTorch example of learning k-space sampling patterns.
It showcases the auto-diff capabilities of the NUFFT operator
with respect to the k-space trajectory in mri-nufft.

Briefly, in this example we try to learn the k-space samples :math:`\mathbf{K}` for the following cost function:

.. math::

\mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \sum_{\ell=1}^L S_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x}_\ell - \mathbf{x}_{sos} ||_2^2

where :math:`S_\ell` is the sensitivity map for the :math:`\ell`-th coil, :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}_\ell` is the image for the :math:`\ell`-th coil, and :math:`\mathbf{x}_{sos} = \sqrt{\sum_{\ell=1}^L x_\ell^2}` is the sum-of-squares image as target image to be reconstructed.

In this example, the forward NUFFT operator :math:`\mathcal{F}_\mathbf{K}` is implemented with `model.operator` while the SENSE operator :math:`model.sense_op` models the term :math:`\mathbf{A} = \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K}`.
For our data, we use a 2D slice of a 3D MRI image from the BrainWeb dataset, and the sensitivity maps are simulated using the `birdcage_maps` function from `sigpy.mri`.
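
As an illustration only (not this example's code), the sum-of-squares target
and the coil combination :math:`\sum_{\ell} S_\ell^* \mathbf{x}_\ell` can be
sketched with toy Gaussian sensitivity maps standing in for ``birdcage_maps``.

.. code-block:: python

    import numpy as np

    rng = np.random.default_rng(0)
    L, n = 4, 32
    yy, xx = np.mgrid[0:n, 0:n] / n
    centers = rng.random((L, 2))                                 # one blob center per coil
    smaps = np.exp(-((yy[None] - centers[:, 0, None, None]) ** 2
                     + (xx[None] - centers[:, 1, None, None]) ** 2) / 0.1)
    image = rng.random((n, n))                                   # toy image
    coil_images = smaps * image                                  # x_l = S_l x
    x_sos = np.sqrt(np.sum(np.abs(coil_images) ** 2, axis=0))    # target x_sos
    combined = np.sum(np.conj(smaps) * coil_images, axis=0)      # sum_l S_l^* x_l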

.. note::
To showcase the features of ``mri-nufft``, we use the ``"cufinufft"`` backend for ``model.operator`` without density compensation and the ``"gpunufft"`` backend for ``model.sense_op`` with density compensation.

.. warning::
This example only showcases the autodiff capabilities; the learned sampling pattern is not scanner-compliant, as the scanner gradients required to implement it violate the hardware constraints. In practice, a projection :math:`\Pi_\mathcal{Q}(\mathbf{K})` into the scanner constraints set :math:`\mathcal{Q}` is recommended (see [Proj]_). This is implemented in the proprietary SPARKLING package [Sparks]_. Users are encouraged to contact the authors if they want to use it.
"""
2 changes: 1 addition & 1 deletion examples/example_learn_samples_multires.py
@@ -13,7 +13,7 @@

.. math::
\mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2

where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator,
:math:`D_\mathbf{K}` is the density compensator for trajectory :math:`\mathbf{K}`,
and :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed.
153 changes: 153 additions & 0 deletions examples/example_trajectory_tools.py
@@ -219,6 +219,159 @@
axes=(0, 2),
)

# %%
# Stack Random
# -------------
#
# A direct extension of the stacking expansion is to distribute the stacks
# according to a random distribution over the :math:`k_z`-axis.
#
# Arguments:
# - ``trajectory (array)``: array of k-space coordinates of size
# :math:`(N_c, N_s, N_d)`
# - ``dim_size (int)``: size of the k-space in voxel units
# - ``center_prop (int or float)``: number of lines (or fraction of ``dim_size``) to keep in the center
# - ``accel (int)``: acceleration factor
# - ``pdf (str or array)``: probability density function for the random distribution
# - ``rng (int or np.random.Generator)``: random number generator
# - ``order (str)``: order of the shots in the stack


trajectory = tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=16,
pdf="uniform",
order="top-down",
rng=42,
)

show_trajectory(trajectory, figure_size=figure_size, one_shot=one_shot)

# %%
# ``trajectory (array)``
# ~~~~~~~~~~~~~~~~~~~~~~
# The main use case is to stack trajectories consisting of
# flat or thick planes that will match the image slices.
arguments = ["Radial", "Spiral", "2D Cones", "3D Cones"]
function = lambda x: tools.stack_random(
planar_trajectories[x],
dim_size=128,
center_prop=0.1,
accel=16,
pdf="gaussian",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``dim_size (int)``
# ~~~~~~~~~~~~~~~~~~
# Size of the k-space in voxel units along the stacking direction. It is used
# to normalize the stack positions and, together with the ``accel`` factor and
# ``center_prop``, to determine the number of stacks.
arguments = [32, 64, 128]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=x,
center_prop=0.1,
accel=8,
pdf="gaussian",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``center_prop (int or float)``
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Number of lines to keep in the center of the k-space, so that the center is
# sampled with a higher density of shots. Together with ``dim_size`` and
# ``accel``, it determines the number of stacks. If a ``float``, this is a
# fraction of the total ``dim_size``. If an ``int``, it is directly the number
# of lines.

arguments = [1, 5, 0.1, 0.5]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=x,
accel=16,
pdf="uniform",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)


# %%
# ``accel (int)``
# ~~~~~~~~~~~~~~~
# Acceleration factor to subsample the outer region of the k-space.
# Note that the acceleration factor does not take into account the center lines.


arguments = [1, 4, 8, 16, 32]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=x,
pdf="uniform",
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``pdf (str or array)``
# ~~~~~~~~~~~~~~~~~~~~~~
# Probability density function for the sampling of the outer region. It can
# either be a string naming a known probability law ("gaussian" or "uniform"),
# or "equispaced" for coherent undersampling (like the one used in GRAPPA).
# It can also be an array specifying a custom density probability; in this
# case, it will be normalized so that ``sum(pdf) = 1``.

dim_size = 128
arguments = [
"gaussian",
"uniform",
"equispaced",
np.arange(dim_size),
]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=32,
pdf=x,
order="top-down",
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# ``order (str)``
# ~~~~~~~~~~~~~~~
# Determines the ordering of the shots in the trajectory.
# Accepted values are "center-out", "top-down" or "random".
dim_size = 128
arguments = [
"center-out",
"random",
"top-down",
]
function = lambda x: tools.stack_random(
planar_trajectories["Spiral"],
dim_size=128,
center_prop=0.1,
accel=32,
pdf="uniform",
order=x,
rng=42,
)
show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size)

# %%
# Rotate
8 changes: 8 additions & 0 deletions src/mrinufft/trajectories/__init__.py
@@ -56,6 +56,12 @@
initialize_3D_wong_radial,
)

from .tools import (
stack_random,
get_random_loc_1d,
)


__all__ = [
# trajectories
"initialize_2D_radial",
@@ -88,7 +94,9 @@
"initialize_3D_random_walk",
"initialize_3D_travelling_salesman",
# tools
"get_random_loc_1d",
"stack",
"stack_random",
"rotate",
"precess",
"conify",
6 changes: 4 additions & 2 deletions src/mrinufft/trajectories/display.py
@@ -57,6 +57,8 @@ class displayConfig:
This can be any of the matplotlib colormaps, or a list of colors."""
one_shot_color: str = "k"
"""Matplotlib color for the highlighted shot, by default ``"k"`` (black)."""
one_shot_linewidth_factor: float = 2
"""Factor to multiply the linewidth of the highlighted shot, by default ``2``."""
gradient_point_color: str = "r"
"""Matplotlib color for gradient constraint points, by default ``"r"`` (red)."""
slewrate_point_color: str = "b"
@@ -243,7 +245,7 @@ def display_2D_trajectory(
trajectory[shot_id, :, 0],
trajectory[shot_id, :, 1],
color=displayConfig.one_shot_color,
linewidth=2 * displayConfig.linewidth,
linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth,
)

# Point out violated constraints if requested
Expand Down Expand Up @@ -379,7 +381,7 @@ def display_3D_trajectory(
trajectory[shot_id, :, 1],
trajectory[shot_id, :, 2],
color=displayConfig.one_shot_color,
linewidth=2 * displayConfig.linewidth,
linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth,
)
trajectory = trajectory.reshape((-1, Nc, Ns, 3))
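
A hypothetical usage sketch (not part of this diff): since
``one_shot_linewidth_factor`` is a plain class attribute on ``displayConfig``,
it can be overridden before plotting to change how strongly the highlighted
shot is emphasized.

    from mrinufft.trajectories.display import displayConfig, display_2D_trajectory

    displayConfig.one_shot_linewidth_factor = 4  # default is 2
    # subsequent display_2D_trajectory(...) calls now draw the highlighted
    # shot four times thicker than the base linewidth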
