diff --git a/examples/GPU/example_fastMRI_UNet.py b/examples/GPU/example_fastMRI_UNet.py index de877a1e..a0d118ca 100644 --- a/examples/GPU/example_fastMRI_UNet.py +++ b/examples/GPU/example_fastMRI_UNet.py @@ -4,30 +4,30 @@ Simple UNet model. ================== -This model is a simplified version of the U-Net architecture, -which is widely used for image segmentation tasks. -This is implemented in the proprietary FASTMRI package [fastmri]_. - -The U-Net model consists of an encoder (downsampling path) and -a decoder (upsampling path) with skip connections between corresponding -layers in the encoder and decoder. -These skip connections help in retaining spatial information +This model is a simplified version of the U-Net architecture, +which is widely used for image segmentation tasks. +This is implemented in the proprietary FASTMRI package [fastmri]_. + +The U-Net model consists of an encoder (downsampling path) and +a decoder (upsampling path) with skip connections between corresponding +layers in the encoder and decoder. +These skip connections help in retaining spatial information that is lost during the downsampling process. -The primary purpose of this model is to perform image reconstruction tasks, -specifically for MRI images. -It takes an input MRI image and reconstructs it to improve the image quality +The primary purpose of this model is to perform image reconstruction tasks, +specifically for MRI images. +It takes an input MRI image and reconstructs it to improve the image quality or to recover missing parts of the image. -This implementation of the UNet model was pulled from the FastMRI Facebook -repository, which is a collaborative research project aimed at advancing +This implementation of the UNet model was pulled from the FastMRI Facebook +repository, which is a collaborative research project aimed at advancing the field of medical imaging using machine learning techniques. .. math:: \mathbf{\hat{x}} = \mathrm{arg} \min_{\mathbf{x}} || \mathcal{U}_\mathbf{\theta}(\mathbf{y}) - \mathbf{x} ||_2^2 -where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image, +where :math:`\mathbf{\hat{x}}` is the reconstructed MRI image, :math:`\mathbf{x}` is the ground truth image, :math:`\mathbf{y}` is the input MRI image (e.g., k-space data), and :math:`\mathcal{U}_\mathbf{\theta}` is the U-Net model parameterized by :math:`\theta`. .. warning:: diff --git a/examples/GPU/example_learn_samples.py b/examples/GPU/example_learn_samples.py index c0e03a04..b65558fb 100644 --- a/examples/GPU/example_learn_samples.py +++ b/examples/GPU/example_learn_samples.py @@ -5,7 +5,7 @@ ====================== A small pytorch example to showcase learning k-space sampling patterns. -This example showcases the auto-diff capabilities of the NUFFT operator +This example showcases the auto-diff capabilities of the NUFFT operator wrt to k-space trajectory in mri-nufft. In this example, we solve the following optimization problem: @@ -13,7 +13,7 @@ .. math:: \mathbf{\hat{K}} = \mathrm{arg} \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2 - + where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed. .. warning:: diff --git a/examples/GPU/example_learn_samples_multicoil.py b/examples/GPU/example_learn_samples_multicoil.py index da8198ec..aa3a45e9 100644 --- a/examples/GPU/example_learn_samples_multicoil.py +++ b/examples/GPU/example_learn_samples_multicoil.py @@ -5,15 +5,15 @@ ========================================= A small pytorch example to showcase learning k-space sampling patterns. -This example showcases the auto-diff capabilities of the NUFFT operator +This example showcases the auto-diff capabilities of the NUFFT operator wrt to k-space trajectory in mri-nufft. Briefly, in this example we try to learn the k-space samples :math:`\mathbf{K}` for the following cost function: .. math:: - \mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} x_\ell - \mathbf{x}_{sos} ||_2^2 - + \mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} x_\ell - \mathbf{x}_{sos} ||_2^2 + where :math:`S_\ell` is the sensitivity map for the :math:`\ell`-th coil, :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator and :math:`D_\mathbf{K}` is the density compensators for trajectory :math:`\mathbf{K}`, :math:`\mathbf{x}_\ell` is the image for the :math:`\ell`-th coil, and :math:`\mathbf{x}_{sos} = \sqrt{\sum_{\ell=1}^L x_\ell^2}` is the sum-of-squares image as target image to be reconstructed. In this example, the forward NUFFT operator :math:`\mathcal{F}_\mathbf{K}` is implemented with `model.operator` while the SENSE operator :math:`model.sense_op` models the term :math:`\mathbf{A} = \sum_{\ell=1}^LS_\ell^* \mathcal{F}_\mathbf{K}^* D_\mathbf{K}`. @@ -21,8 +21,8 @@ .. note:: To showcase the features of ``mri-nufft``, we use `` - "cufinufft"`` backend for ``model.operator`` without density compensation and ``"gpunufft"`` backend for ``model.sense_op`` with density compensation. - + "cufinufft"`` backend for ``model.operator`` without density compensation and ``"gpunufft"`` backend for ``model.sense_op`` with density compensation. + .. warning:: This example only showcases the autodiff capabilities, the learned sampling pattern is not scanner compliant as the scanner gradients required to implement it violate the hardware constraints. In practice, a projection :math:`\Pi_\mathcal{Q}(\mathbf{K})` into the scanner constraints set :math:`\mathcal{Q}` is recommended (see [Proj]_). This is implemented in the proprietary SPARKLING package [Sparks]_. Users are encouraged to contact the authors if they want to use it. """ diff --git a/examples/example_learn_samples_multires.py b/examples/example_learn_samples_multires.py index f84737dd..abca55f5 100644 --- a/examples/example_learn_samples_multires.py +++ b/examples/example_learn_samples_multires.py @@ -13,7 +13,7 @@ .. math:: \mathbf{\hat{K}} = arg \min_{\mathbf{K}} || \mathcal{F}_\mathbf{K}^* D_\mathbf{K} \mathcal{F}_\mathbf{K} \mathbf{x} - \mathbf{x} ||_2^2 - + where :math:`\mathcal{F}_\mathbf{K}` is the forward NUFFT operator, :math:`D_\mathbf{K}` is the density compensator for trajectory :math:`\mathbf{K}`, and :math:`\mathbf{x}` is the MR image which is also the target image to be reconstructed. diff --git a/examples/example_trajectory_tools.py b/examples/example_trajectory_tools.py index 83afeb5d..8e4dfac3 100644 --- a/examples/example_trajectory_tools.py +++ b/examples/example_trajectory_tools.py @@ -219,6 +219,159 @@ axes=(0, 2), ) +# %% +# Stack Random +# ------------- +# +# A direct extension of the stacking expansion is to distribute the stacks +# according to a random distribution over the :math:`k_z`-axis. +# +# Arguments: +# - ``trajectory (array)``: array of k-space coordinates of size +# :math:`(N_c, N_s, N_d)` +# - ``dim_size (int)``: size of the kspace in voxel units +# - ``center_prop (int or float)`` : number of line +# - ``acceleration (int)``: Acceleration factor +# - ``pdf (str or array)``: Probability density function for the random distribution +# - ``rng (int or np.random.Generator)``: Random number generator +# - ``order (int)``: Order of the shots in the stack + + +trajectory = tools.stack_random( + planar_trajectories["Spiral"], + dim_size=128, + center_prop=0.1, + accel=16, + pdf="uniform", + order="top-down", + rng=42, +) + +show_trajectory(trajectory, figure_size=figure_size, one_shot=one_shot) + +# %% +# ``trajectory (array)`` +# ~~~~~~~~~~~~~~~~~~~~~~ +# The main use case is to stack trajectories consisting of +# flat or thick planes that will match the image slices. +arguments = ["Radial", "Spiral", "2D Cones", "3D Cones"] +function = lambda x: tools.stack_random( + planar_trajectories[x], + dim_size=128, + center_prop=0.1, + accel=16, + pdf="gaussian", + order="top-down", + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) + +# %% +# ``dim_size (int)`` +# ~~~~~~~~~~~~~~~~~~ +# Size of the k-space in voxel units over the stacking direction. It +# is used to normalize the stack positions, and is used with the ``accel`` +# factor and ``center_prop`` to determine the number of stacks. +arguments = [32, 64, 128] +function = lambda x: tools.stack_random( + planar_trajectories["Spiral"], + dim_size=x, + center_prop=0.1, + accel=8, + pdf="gaussian", + order="top-down", + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) + +# %% +# ``center_prop (int or float)`` +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# Number of lines to keep in the center of the k-space. It is used to determine +# the number of stacks and the acceleration factor, and to keep the center of +# the k-space with a higher density of shots. If a ``float`` this is a fraction +# of the total ``dim_size``. If ``int`` it is directly the number of lines. + +arguments = [1, 5, 0.1, 0.5] +function = lambda x: tools.stack_random( + planar_trajectories["Spiral"], + dim_size=128, + center_prop=x, + accel=16, + pdf="uniform", + order="top-down", + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) + + +# %% +# ``accel (int)`` +# ~~~~~~~~~~~~~~~ +# Acceleration factor to subsample the outer region of the k-space. +# Note that the acceleration factor does not take into account the center lines. + + +arguments = [1, 4, 8, 16, 32] +function = lambda x: tools.stack_random( + planar_trajectories["Spiral"], + dim_size=128, + center_prop=0.1, + accel=x, + pdf="uniform", + order="top-down", + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) + +# %% +# ``pdf (str or array)`` +# ~~~~~~~~~~~~~~~~~~~~~~ +# Probability density function for the sampling of the outer region. It can +# either be a string to use a known probability law ("gaussian" or "uniform") or +# "equispaced" for a coherent undersampling (like the one used in GRAPPA). It +# can also be a array, for using a customed density probability. +# In this case, it will be normalized so that ``sum(pdf) =1``. + +dim_size = 128 +arguments = [ + "gaussian", + "uniform", + "equispaced", + np.arange(dim_size), +] +function = lambda x: tools.stack_random( + planar_trajectories["Spiral"], + dim_size=128, + center_prop=0.1, + accel=32, + pdf=x, + order="top-down", + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) + +# %% +# ``order (str)`` +# ~~~~~~~~~~~~~~~ +# Determine the ordering of the shot in the trajectory. +# Accepeted values are "center-out", "top-down" or "random". +dim_size = 128 +arguments = [ + "center-out", + "random", + "top-down", +] +function = lambda x: tools.stack_random( + planar_trajectories["Spiral"], + dim_size=128, + center_prop=0.1, + accel=32, + pdf="uniform", + order=x, + rng=42, +) +show_trajectories(function, arguments, one_shot=one_shot, subfig_size=subfigure_size) # %% # Rotate diff --git a/src/mrinufft/trajectories/__init__.py b/src/mrinufft/trajectories/__init__.py index 77c588cd..e8c34a71 100644 --- a/src/mrinufft/trajectories/__init__.py +++ b/src/mrinufft/trajectories/__init__.py @@ -56,6 +56,12 @@ initialize_3D_wong_radial, ) +from .tools import ( + stack_random, + get_random_loc_1d, +) + + __all__ = [ # trajectories "initialize_2D_radial", @@ -88,7 +94,9 @@ "initialize_3D_random_walk", "initialize_3D_travelling_salesman", # tools + "get_random_loc_1d", "stack", + "stack_random", "rotate", "precess", "conify", diff --git a/src/mrinufft/trajectories/display.py b/src/mrinufft/trajectories/display.py index 42e07cbe..c848f450 100644 --- a/src/mrinufft/trajectories/display.py +++ b/src/mrinufft/trajectories/display.py @@ -57,6 +57,8 @@ class displayConfig: This can be any of the matplotlib colormaps, or a list of colors.""" one_shot_color: str = "k" """Matplotlib color for the highlighted shot, by default ``"k"`` (black).""" + one_shot_linewidth_factor: float = 2 + """Factor to multiply the linewidth of the highlighted shot, by default ``2``.""" gradient_point_color: str = "r" """Matplotlib color for gradient constraint points, by default ``"r"`` (red).""" slewrate_point_color: str = "b" @@ -243,7 +245,7 @@ def display_2D_trajectory( trajectory[shot_id, :, 0], trajectory[shot_id, :, 1], color=displayConfig.one_shot_color, - linewidth=2 * displayConfig.linewidth, + linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth, ) # Point out violated constraints if requested @@ -379,7 +381,7 @@ def display_3D_trajectory( trajectory[shot_id, :, 1], trajectory[shot_id, :, 2], color=displayConfig.one_shot_color, - linewidth=2 * displayConfig.linewidth, + linewidth=displayConfig.one_shot_linewidth_factor * displayConfig.linewidth, ) trajectory = trajectory.reshape((-1, Nc, Ns, 3)) diff --git a/src/mrinufft/trajectories/tools.py b/src/mrinufft/trajectories/tools.py index 02ca23ae..b6490b02 100644 --- a/src/mrinufft/trajectories/tools.py +++ b/src/mrinufft/trajectories/tools.py @@ -5,9 +5,10 @@ import numpy as np from numpy.typing import NDArray from scipy.interpolate import CubicSpline, interp1d +from scipy.stats import norm from .maths import Rv, Rx, Ry, Rz -from .utils import KMAX, initialize_tilt +from .utils import KMAX, initialize_tilt, VDSpdf, VDSorder ################ # DIRECT TOOLS # @@ -784,3 +785,183 @@ def radialize_center( if in_out: return _radialize_in_out(trajectory, nb_samples) return _radialize_center_out(trajectory, nb_samples) + + +################# +# Randomization # +################# + + +def _flip2center(mask_cols: list[int], center_value: int) -> np.ndarray: + """ + Reorder a list by starting by a center_position and alternating left/right. + + Parameters + ---------- + mask_cols: list or np.array + List of columns to reorder. + center_pos: int + Position of the center column. + + Returns + ------- + np.array: reordered columns. + """ + center_pos = np.argmin(np.abs(np.array(mask_cols) - center_value)) + mask_cols = list(mask_cols) + left = mask_cols[center_pos::-1] + right = mask_cols[center_pos + 1 :] + new_cols = [] + while left or right: + if left: + new_cols.append(left.pop(0)) + if right: + new_cols.append(right.pop(0)) + return np.array(new_cols) + + +def get_random_loc_1d( + dim_size: int, + center_prop: float | int, + accel: float = 4, + pdf: Literal["uniform", "gaussian", "equispaced"] | NDArray = "uniform", + rng: int | np.random.Generator | None = None, + order: Literal["center-out", "top-down", "random"] = "center-out", +) -> NDArray: + """Get slice index at a random position. + + Parameters + ---------- + dim_size: int + Dimension size + center_prop: float or int + Proportion of center of kspace to continuouly sample + accel: float + Undersampling/Acceleration factor + pdf: str, optional + Probability density function for the remaining samples. + "gaussian" (default) or "uniform" or np.array + rng: int or np.random.Generator + random state + order: str + Order of the lines, "center-out" (default), "random" or "top-down" + + Returns + ------- + np.ndarray: array of size dim_size/accel. + """ + order = VDSorder(order) + pdf = VDSpdf(pdf) if isinstance(pdf, str) else pdf + if accel == 0 or accel == 1: + return np.arange(dim_size) # type: ignore + elif accel < 0: + raise ValueError("acceleration factor should be positive.") + elif isinstance(accel, float): + raise ValueError("acceleration factor should be an integer.") + + indexes = list(range(dim_size)) + + if not isinstance(center_prop, int): + center_prop = int(center_prop * dim_size) + + center_start = (dim_size - center_prop) // 2 + center_stop = (dim_size + center_prop) // 2 + center_indexes = indexes[center_start:center_stop] + borders = np.asarray([*indexes[:center_start], *indexes[center_stop:]]) + + n_samples_borders = (dim_size - len(center_indexes)) // accel + if n_samples_borders < 1: + raise ValueError( + "acceleration factor, center_prop and dimension not compatible." + "Edges will not be sampled. " + ) + rng = np.random.default_rng(rng) # get RNG from a seed or existing rng. + + def _get_samples(p: np.typing.ArrayLike) -> list[int]: + p = p / np.sum(p) # automatic casting if needed + return list(rng.choice(borders, size=n_samples_borders, replace=False, p=p)) + + if isinstance(pdf, np.ndarray): + if len(pdf) == dim_size: + # extract the borders + p = pdf[borders] + elif len(pdf) == len(borders): + p = pdf + else: + raise ValueError("Invalid size for probability.") + sampled_in_border = _get_samples(p) + + elif pdf == VDSpdf.GAUSSIAN: + p = norm.pdf(np.linspace(norm.ppf(0.001), norm.ppf(0.999), len(borders))) + sampled_in_border = _get_samples(p) + elif pdf == VDSpdf.UNIFORM: + p = np.ones(len(borders)) + sampled_in_border = _get_samples(p) + elif pdf == VDSpdf.EQUISPACED: + sampled_in_border = list(borders[::accel]) + + else: + raise ValueError("Unsupported value for pdf use any of . ") + # TODO: allow custom pdf as argument (vector or function.) + + line_locs = np.array(sorted(center_indexes + sampled_in_border)) + # apply order of lines + if order == VDSorder.CENTER_OUT: + line_locs = _flip2center(sorted(line_locs), dim_size // 2) + elif order == VDSorder.RANDOM: + line_locs = rng.permutation(line_locs) + elif order == VDSorder.TOP_DOWN: + line_locs = np.array(sorted(line_locs)) + else: + raise ValueError(f"Unknown direction '{order}'.") + return (line_locs / dim_size) * 2 * KMAX - KMAX # rescale to [-0.5,0.5] + + +def stack_random( + trajectory: NDArray, + dim_size: int, + center_prop: float | int = 0.0, + accel: float | int = 4, + pdf: Literal["uniform", "gaussian", "equispaced"] | NDArray = "uniform", + rng: int | np.random.Generator | None = None, + order: Literal["center-out", "top-down", "random"] = "center-out", +): + """Stack a 2D trajectory with random location. + + Parameters + ---------- + traj: np.ndarray + Existing 2D trajectory. + dim_size: int + Size of the k_z dimension + center_prop: int or float + Number of line or proportion of slice to sample in the center of the k-space + accel: int + Undersampling/Acceleration factor + pdf: str or np.array + Probability density function for the remaining samples. + "uniform" (default), "gaussian" or np.array + rng: random state + order: str + Order of the lines, "center-out" (default), "random" or "top-down" + + Returns + ------- + numpy.ndarray + The 3D trajectory stacked along the :math:`k_z` axis. + """ + line_locs = get_random_loc_1d(dim_size, center_prop, accel, pdf, rng, order) + if len(trajectory.shape) == 2: + Nc, Ns = 1, trajectory.shape[0] + else: + Nc, Ns = trajectory.shape[:2] + + new_trajectory = np.zeros((len(line_locs), Nc, Ns, 3)) + for i, loc in enumerate(line_locs): + new_trajectory[i, :, :, :2] = trajectory[..., :2] + if trajectory.shape[-1] == 3: + new_trajectory[i, :, :, 2] = trajectory[..., 2] + loc + else: + new_trajectory[i, :, :, 2] = loc + + return new_trajectory.reshape(-1, Ns, 3) diff --git a/src/mrinufft/trajectories/utils.py b/src/mrinufft/trajectories/utils.py index c1124e11..4f5a0e07 100644 --- a/src/mrinufft/trajectories/utils.py +++ b/src/mrinufft/trajectories/utils.py @@ -43,6 +43,12 @@ class FloatEnum(float, Enum, metaclass=CaseInsensitiveEnumMeta): pass +class StrEnum(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """An Enum for str that is case insensitive for its attributes.""" + + pass + + class Gammas(FloatEnum): """Enumerate gyromagnetic ratios for common nuclei in MR.""" @@ -94,7 +100,7 @@ class NormShapes(FloatEnum): OCTAHEDRON = L1 -class Tilts(str, Enum): +class Tilts(StrEnum): r"""Enumerate available tilts. Notes @@ -120,7 +126,7 @@ class Tilts(str, Enum): MRI = MRI_GOLDEN -class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): +class Packings(StrEnum): """Enumerate available packing method for shots. It is mostly used for wave-CAIPI trajectory @@ -146,6 +152,27 @@ class Packings(str, Enum, metaclass=CaseInsensitiveEnumMeta): SPIRAL = FIBONACCI +############################# +# Variable Density Sampling # +############################# + + +class VDSorder(StrEnum): + """Available ordering for variable density sampling.""" + + CENTER_OUT = "center-out" + RANDOM = "random" + TOP_DOWN = "top-down" + + +class VDSpdf(StrEnum): + """Available law for variable density sampling.""" + + GAUSSIAN = "gaussian" + UNIFORM = "uniform" + EQUISPACED = "equispaced" + + ############### # CONSTRAINTS # ###############