From d20897f4ea8593ebe6c40335edccd0d6db1305a4 Mon Sep 17 00:00:00 2001 From: Nick Tustison Date: Fri, 12 Apr 2024 16:12:46 -0700 Subject: [PATCH] ENH: Add masked/bspline fitting variant of Nyul histogram matching. --- ants/utils/__init__.py | 2 +- ants/utils/histogram_match_image.py | 108 +++++++++++++++++++++++++++- 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/ants/utils/__init__.py b/ants/utils/__init__.py index 0b379019..3f3faadb 100644 --- a/ants/utils/__init__.py +++ b/ants/utils/__init__.py @@ -16,7 +16,7 @@ from .get_mask import get_mask from .get_neighborhood import (get_neighborhood_in_mask, get_neighborhood_at_voxel) -from .histogram_match_image import histogram_match_image +from .histogram_match_image import histogram_match_image, histogram_match_image2 from .histogram_equalize_image import histogram_equalize_image from .hausdorff_distance import hausdorff_distance from .image_similarity import image_similarity diff --git a/ants/utils/histogram_match_image.py b/ants/utils/histogram_match_image.py index 8fd1fd66..fb418380 100644 --- a/ants/utils/histogram_match_image.py +++ b/ants/utils/histogram_match_image.py @@ -1,11 +1,14 @@ -__all__ = ['histogram_match_image'] +__all__ = ['histogram_match_image', + 'histogram_match_image2'] -import math +import numpy as np -from ..core import ants_image as iio +from ..core import ants_image_io as iio from .. import utils +from ..utils import fit_bspline_object_to_scattered_data + def histogram_match_image(source_image, reference_image, number_of_histogram_bins=255, number_of_match_points=64, use_threshold_at_mean_intensity=False): """ @@ -51,3 +54,102 @@ def histogram_match_image(source_image, reference_image, number_of_histogram_bin return new_image +def histogram_match_image2(source_image, reference_image, + source_mask=None, reference_mask=None, + match_points=64, + transform_domain_size=255): + """ + Transform image intensities based on histogram mapping. + + Apply B-spline 1-D maps to an input image for intensity warping. + + Arguments + --------- + source_image : ANTsImage + source image + + reference_image : ANTsImage + reference image + + source_mask : ANTsImage + source mask + + reference_mask : ANTsImage + reference mask + + match_points : integer or tuple + Parametric points at which the intensity transform displacements are + specified between [0, 1], i.e. quantiles. Alternatively, a single number + can be given and the sequence is linearly spaced in [0, 1]. + + transform_domain_size : integer + Defines the sampling resolution of the B-spline warping. + + Returns + ------- + ANTs image + + Example + ------- + >>> import ants + >>> src_img = ants.image_read(ants.get_data('r16')) + >>> ref_img = ants.image_read(ants.get_data('r64')) + >>> src_ref = ants.histogram_match_image(src_img, ref_img) + """ + + if not isinstance(match_points, int): + if any(b < 0 for b in match_points) and any(b > 1 for b in match_points): + raise ValueError("If specifying match_points as a vector, values must be in the range [0, 1]") + + # Use entire image if mask isn't specified + if source_mask is None: + source_mask = source_image * 0 + 1 + if reference_mask is None: + reference_mask = reference_image * 0 + 1 + + source_array = source_image.numpy() + source_mask_array = source_mask.numpy() + source_masked_min = source_image[source_mask != 0].min() + source_masked_max = source_image[source_mask != 0].max() + + reference_array = reference_image.numpy() + reference_mask_array = reference_mask.numpy() + + parametric_points = None + if not isinstance(match_points, int): + parametric_points = match_points + else: + parametric_points = np.linspace(0, 1, match_points) + + source_intensity_quantiles = np.quantile(source_array[source_mask_array != 0], parametric_points) + reference_intensity_quantiles = np.quantile(reference_array[reference_mask_array != 0], parametric_points) + displacements = reference_intensity_quantiles - source_intensity_quantiles + + scattered_data = np.reshape(displacements, (len(displacements), 1)) + parametric_data = np.reshape(parametric_points * (source_masked_max - source_masked_min) + source_masked_min, (len(parametric_points), 1)) + + transform_domain_origin = source_masked_min + transform_domain_spacing = (source_masked_max - transform_domain_origin) / (transform_domain_size - 1) + + bspline_histogram_transform = fit_bspline_object_to_scattered_data(scattered_data, + parametric_data, [transform_domain_origin], [transform_domain_spacing], [transform_domain_size], + data_weights=None, is_parametric_dimension_closed=None, number_of_fitting_levels=8, + mesh_size=1, spline_order=3) + + transform_domain = np.linspace(source_masked_min, source_masked_max, transform_domain_size) + + transformed_source_array = source_image.numpy() + for i in range(len(transform_domain) - 1): + indices = np.where((source_array >= transform_domain[i]) & (source_array < transform_domain[i+1])) + intensities = source_array[indices] + + alpha = (intensities - transform_domain[i])/(transform_domain[i+1] - transform_domain[i]) + xfrm = alpha * (bspline_histogram_transform[i+1] - bspline_histogram_transform[i]) + bspline_histogram_transform[i] + transformed_source_array[indices] = intensities + xfrm + + transformed_source_image = iio.from_numpy(transformed_source_array, origin=source_image.origin, + spacing=source_image.spacing, direction=source_image.direction) + transformed_source_image[source_mask == 0] = source_image[source_mask == 0] + + return(transformed_source_image) +