Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add thin-plate splines. #460

Merged
merged 1 commit into from
May 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ants/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pybind11_add_module(cropImage LOCAL_cropImage.cxx)
pybind11_add_module(fitBsplineObjectToScatteredData LOCAL_fitBsplineObjectToScatteredData.cxx)
pybind11_add_module(fitBsplineDisplacementField LOCAL_fitBsplineDisplacementField.cxx)
pybind11_add_module(fitBsplineDisplacementFieldToScatteredData LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx)
pybind11_add_module(fitThinPlateSplineDisplacementFieldToScatteredData LOCAL_fitThinPlateSplineDisplacementFieldToScatteredData.cxx)
pybind11_add_module(fsl2antstransform LOCAL_fsl2antstransform.cxx)
pybind11_add_module(getNeighborhoodMatrix LOCAL_getNeighborhoodMatrix.cxx)
pybind11_add_module(hausdorffDistance LOCAL_hausdorffDistance.cxx)
Expand Down Expand Up @@ -129,6 +130,7 @@ target_link_libraries(cropImage PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineObjectToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineDisplacementField PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineDisplacementFieldToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitThinPlateSplineDisplacementFieldToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fsl2antstransform PRIVATE ${ITK_LIBRARIES})
target_link_libraries(getNeighborhoodMatrix PRIVATE ${ITK_LIBRARIES})
target_link_libraries(hausdorffDistance PRIVATE ${ITK_LIBRARIES})
Expand Down
149 changes: 149 additions & 0 deletions ants/lib/LOCAL_fitThinPlateSplineDisplacementFieldToScatteredData.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <exception>
#include <vector>
#include <string>

#include "itkImage.h"
#include "itkPointSet.h"
#include "itkThinPlateSplineKernelTransform.h"
#include "itkImageRegionIteratorWithIndex.h"

#include "LOCAL_antsImage.h"

namespace py = pybind11;

template<unsigned int Dimension>
py::capsule fitThinPlateSplineVectorImageToScatteredDataHelper(
py::array_t<double> displacementOrigins,
py::array_t<double> displacements,
py::array_t<double> origin,
py::array_t<double> spacing,
py::array_t<unsigned int> size,
py::array_t<double> direction
)
{
using RealType = float;

using ANTsFieldType = itk::VectorImage<RealType, Dimension>;
using ANTsFieldPointerType = typename ANTsFieldType::Pointer;

using VectorType = itk::Vector<RealType, Dimension>;

using ITKFieldType = itk::Image<VectorType, Dimension>;
using IteratorType = itk::ImageRegionIteratorWithIndex<ITKFieldType>;

using CoordinateRepType = float;
using TransformType = itk::ThinPlateSplineKernelTransform<CoordinateRepType, Dimension>;
using PointType = itk::Point<CoordinateRepType, Dimension>;
using PointSetType = typename TransformType::PointSetType;

auto tps = TransformType::New();

////////////////////////////
//
// Define the output thin-plate spline field domain
//

auto field = ITKFieldType::New();

auto originP = origin.unchecked<1>();
auto spacingP = spacing.unchecked<1>();
auto sizeP = size.unchecked<1>();
auto directionP = direction.unchecked<2>();

if( originP.shape(0) == 0 || sizeP.shape(0) == 0 || spacingP.shape(0) == 0 || directionP.shape(0) == 0 )
{
throw std::invalid_argument( "Thin-plate spline domain is not specified." );
}
else
{
typename ITKFieldType::PointType fieldOrigin;
typename ITKFieldType::SpacingType fieldSpacing;
typename ITKFieldType::SizeType fieldSize;
typename ITKFieldType::DirectionType fieldDirection;

for( unsigned int d = 0; d < Dimension; d++ )
{
fieldOrigin[d] = originP(d);
fieldSpacing[d] = spacingP(d);
fieldSize[d] = sizeP(d);
for( unsigned int e = 0; e < Dimension; e++ )
{
fieldDirection(d, e) = directionP(d, e);
}
}
field->SetRegions( fieldSize );
field->SetOrigin( fieldOrigin );
field->SetSpacing( fieldSpacing );
field->SetDirection( fieldDirection );
field->Allocate();
}

auto sourceLandmarks = PointSetType::New();
auto targetLandmarks = PointSetType::New();
typename PointSetType::PointsContainer::Pointer sourceLandmarkContainer = sourceLandmarks->GetPoints();
typename PointSetType::PointsContainer::Pointer targetLandmarkContainer = targetLandmarks->GetPoints();

PointType sourcePoint;
PointType targetPoint;

auto displacementOriginsP = displacementOrigins.unchecked<2>();
auto displacementsP = displacements.unchecked<2>();
unsigned int numberOfPoints = displacementsP.shape(0);

for( unsigned int n = 0; n < numberOfPoints; n++ )
{
for( unsigned int d = 0; d < Dimension; d++ )
{
sourcePoint[d] = displacementOriginsP(n, d);
targetPoint[d] = displacementOriginsP(n, d) + displacementsP(n, d);
}
sourceLandmarkContainer->InsertElement( n, sourcePoint );
targetLandmarkContainer->InsertElement( n, targetPoint );
}

tps->SetSourceLandmarks( sourceLandmarks );
tps->SetTargetLandmarks( targetLandmarks );
tps->ComputeWMatrix();

//////////////////////////
//
// Now convert back to vector image type.
//

ANTsFieldPointerType antsField = ANTsFieldType::New();
antsField->CopyInformation( field );
antsField->SetRegions( field->GetRequestedRegion() );
antsField->SetVectorLength( Dimension );
antsField->Allocate();

typename TransformType::InputPointType source;
typename TransformType::OutputPointType target;

IteratorType It( field, field->GetLargestPossibleRegion() );
for( It.GoToBegin(); !It.IsAtEnd(); ++It )
{
field->TransformIndexToPhysicalPoint( It.GetIndex(), source );
target = tps->TransformPoint( source );

typename ANTsFieldType::PixelType antsVector( Dimension );
for( unsigned int d = 0; d < Dimension; d++ )
{
antsVector[d] = target[d] - source[d];
}
antsField->SetPixel( It.GetIndex(), antsVector );
}

return wrap< ANTsFieldType >( antsField );
}

PYBIND11_MODULE(fitThinPlateSplineDisplacementFieldToScatteredData, m)
{
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD2", &fitThinPlateSplineVectorImageToScatteredDataHelper<2>);
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD3", &fitThinPlateSplineVectorImageToScatteredDataHelper<3>);
}

1 change: 1 addition & 0 deletions ants/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .fitBsplineObjectToScatteredData import *
from .fitBsplineDisplacementField import *
from .fitBsplineDisplacementFieldToScatteredData import *
from .fitThinPlateSplineDisplacementFieldToScatteredData import *
from .fsl2antstransform import *
from .getNeighborhoodMatrix import *
from .hausdorffDistance import *
Expand Down
21 changes: 18 additions & 3 deletions ants/registration/landmark_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..core import ants_image_io as iio2
from ..utils import fit_bspline_displacement_field
from ..utils import fit_bspline_object_to_scattered_data
from ..utils import fit_thin_plate_spline_displacement_field
from ..utils import integrate_velocity_field
from ..utils import smooth_image
from ..utils import compose_displacement_fields
Expand Down Expand Up @@ -62,7 +63,7 @@ def fit_transform_to_paired_points(moving_points,
of points and d is the dimensionality.

transform_type : character
'rigid', 'similarity', "affine', 'bspline', 'diffeo', 'syn', or 'time-varying (tv)'.
'rigid', 'similarity', "affine', 'bspline', 'tps', 'diffeo', 'syn', or 'time-varying (tv)'.

regularization : scalar
Ridge penalty in [0,1] for linear transforms.
Expand Down Expand Up @@ -159,13 +160,13 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
has_components=True)
return(field)

allowed_transforms = ['rigid', 'affine', 'similarity', 'bspline', 'diffeo', 'syn', 'tv', 'time-varying']
allowed_transforms = ['rigid', 'affine', 'similarity', 'bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']
if not transform_type.lower() in allowed_transforms:
raise ValueError(transform_type + " transform not supported.")

transform_type = transform_type.lower()

if domain_image is None and transform_type in ['bspline', 'diffeo', 'syn', 'tv', 'time-varying']:
if domain_image is None and transform_type in ['bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']:
raise ValueError("Domain image needs to be specified.")

if not fixed_points.shape == moving_points.shape:
Expand Down Expand Up @@ -239,6 +240,20 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):

return xfrm

elif transform_type == "tps":

tps_displacement_field = fit_thin_plate_spline_displacement_field(
displacement_origins=fixed_points,
displacements=moving_points - fixed_points,
origin=domain_image.origin,
spacing=domain_image.spacing,
size=domain_image.shape,
direction=domain_image.direction)

xfrm = txio.transform_from_displacement_field(tps_displacement_field)

return xfrm

elif transform_type == "diffeo":

if verbose:
Expand Down
1 change: 1 addition & 0 deletions ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .denoise_image import *
from .fit_bspline_object_to_scattered_data import *
from .fit_bspline_displacement_field import *
from .fit_thin_plate_spline_displacement_field import *
from .get_ants_data import *
from .get_centroids import *
from .get_mask import *
Expand Down
111 changes: 111 additions & 0 deletions ants/utils/fit_thin_plate_spline_displacement_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
__all__ = ["fit_thin_plate_spline_displacement_field"]

import numpy as np

from ..core import ants_image as iio
from .. import core
from .. import utils


def fit_thin_plate_spline_displacement_field(displacement_origins=None,
displacements=None,
origin=None,
spacing=None,
size=None,
direction=None):

"""
Fit a thin-plate spline object to a a set of points with associated displacements.
This is basically a wrapper for the ITK filter

https://itk.org/Doxygen/html/itkThinPlateSplineKernelTransform_8h.html

ANTsR function: `fitThinPlateSplineToDisplacementField`

Arguments
---------

displacement_origins : 2-D numpy array
Matrix (number_of_points x dimension) defining the origins of the input
displacement points. Default = None.

displacements : 2-D numpy array
Matrix (number_of_points x dimension) defining the displacements of the input
displacement points. Default = None.

origin : n-D tuple
Defines the physical origin of the B-spline object.

spacing : n-D tuple
Defines the physical spacing of the B-spline object.

size : n-D tuple
Defines the size (length) of the spline object. Note that the length of the
spline object in dimension d is defined as spacing[d] * size[d]-1.

direction : 2-D numpy array
Booleans defining whether or not the corresponding parametric dimension is
closed (e.g., closed loop). Default = None.

Returns
-------
Returns an ANTsImage.

Example
-------
>>> # Perform 2-D fitting
>>>
>>> import ants, numpy
>>>
>>> points = numpy.array([[-50, -50]])
>>> deltas = numpy.array([[10, 10]])
>>>
>>> tps_field = ants.fit_thin_plate_spline_displacement_field(
>>> displacement_origins=points, displacements=deltas,
>>> origin=[0.0, 0.0], spacing=[1.0, 1.0], size=[100, 100],
>>> direction=numpy.array([[-1, 0], [0, -1]]))
"""

dimensionality = displacement_origins.shape[1]
if displacements.shape[1] != dimensionality:
raise ValueError("Dimensionality between origins and displacements does not match.")

if displacement_origins is None or displacement_origins is None:
raise ValueError("Missing input. Input point set (origins + displacements) needs to be specified." )

if origin is not None and len(origin) != dimensionality:
raise ValueError("Origin is not of length dimensionality.")

if spacing is not None and len(spacing) != dimensionality:
raise ValueError("Spacing is not of length dimensionality.")

if size is not None and len(size) != dimensionality:
raise ValueError("Size is not of length dimensionality.")

if direction is not None and (direction.shape[0] != dimensionality and direction.shape[1] != dimensionality):
raise ValueError("Direction is not of shape dimensionality x dimensionality.")

# It would seem that pybind11 doesn't really play nicely when the
# arguments are 'None'

if origin is None:
origin = np.empty(0)

if spacing is None:
spacing = np.empty(0)

if size is None:
size = np.empty(0)

if direction is None:
direction = np.empty((0, 0))

tps_field = None
libfn = utils.get_lib_fn("fitThinPlateSplineDisplacementFieldToScatteredDataD%i" % (dimensionality))
tps_field = libfn(displacement_origins, displacements, origin, spacing, size, direction)

tps_displacement_field = iio.ANTsImage(pixeltype='float',
dimension=dimensionality, components=dimensionality,
pointer=tps_field).clone('float')
return tps_displacement_field