Skip to content

Commit

Permalink
Merge pull request #460 from ANTsX/TPS
Browse files Browse the repository at this point in the history
ENH:  Add thin-plate splines.
  • Loading branch information
ntustison authored May 13, 2023
2 parents 719bee7 + 36a9649 commit 8d15800
Show file tree
Hide file tree
Showing 6 changed files with 282 additions and 3 deletions.
2 changes: 2 additions & 0 deletions ants/lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ pybind11_add_module(cropImage LOCAL_cropImage.cxx)
pybind11_add_module(fitBsplineObjectToScatteredData LOCAL_fitBsplineObjectToScatteredData.cxx)
pybind11_add_module(fitBsplineDisplacementField LOCAL_fitBsplineDisplacementField.cxx)
pybind11_add_module(fitBsplineDisplacementFieldToScatteredData LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx)
pybind11_add_module(fitThinPlateSplineDisplacementFieldToScatteredData LOCAL_fitThinPlateSplineDisplacementFieldToScatteredData.cxx)
pybind11_add_module(fsl2antstransform LOCAL_fsl2antstransform.cxx)
pybind11_add_module(getNeighborhoodMatrix LOCAL_getNeighborhoodMatrix.cxx)
pybind11_add_module(hausdorffDistance LOCAL_hausdorffDistance.cxx)
Expand Down Expand Up @@ -129,6 +130,7 @@ target_link_libraries(cropImage PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineObjectToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineDisplacementField PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitBsplineDisplacementFieldToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fitThinPlateSplineDisplacementFieldToScatteredData PRIVATE ${ITK_LIBRARIES})
target_link_libraries(fsl2antstransform PRIVATE ${ITK_LIBRARIES})
target_link_libraries(getNeighborhoodMatrix PRIVATE ${ITK_LIBRARIES})
target_link_libraries(hausdorffDistance PRIVATE ${ITK_LIBRARIES})
Expand Down
149 changes: 149 additions & 0 deletions ants/lib/LOCAL_fitThinPlateSplineDisplacementFieldToScatteredData.cxx
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>

#include <exception>
#include <vector>
#include <string>

#include "itkImage.h"
#include "itkPointSet.h"
#include "itkThinPlateSplineKernelTransform.h"
#include "itkImageRegionIteratorWithIndex.h"

#include "LOCAL_antsImage.h"

namespace py = pybind11;

template<unsigned int Dimension>
py::capsule fitThinPlateSplineVectorImageToScatteredDataHelper(
py::array_t<double> displacementOrigins,
py::array_t<double> displacements,
py::array_t<double> origin,
py::array_t<double> spacing,
py::array_t<unsigned int> size,
py::array_t<double> direction
)
{
using RealType = float;

using ANTsFieldType = itk::VectorImage<RealType, Dimension>;
using ANTsFieldPointerType = typename ANTsFieldType::Pointer;

using VectorType = itk::Vector<RealType, Dimension>;

using ITKFieldType = itk::Image<VectorType, Dimension>;
using IteratorType = itk::ImageRegionIteratorWithIndex<ITKFieldType>;

using CoordinateRepType = float;
using TransformType = itk::ThinPlateSplineKernelTransform<CoordinateRepType, Dimension>;
using PointType = itk::Point<CoordinateRepType, Dimension>;
using PointSetType = typename TransformType::PointSetType;

auto tps = TransformType::New();

////////////////////////////
//
// Define the output thin-plate spline field domain
//

auto field = ITKFieldType::New();

auto originP = origin.unchecked<1>();
auto spacingP = spacing.unchecked<1>();
auto sizeP = size.unchecked<1>();
auto directionP = direction.unchecked<2>();

if( originP.shape(0) == 0 || sizeP.shape(0) == 0 || spacingP.shape(0) == 0 || directionP.shape(0) == 0 )
{
throw std::invalid_argument( "Thin-plate spline domain is not specified." );
}
else
{
typename ITKFieldType::PointType fieldOrigin;
typename ITKFieldType::SpacingType fieldSpacing;
typename ITKFieldType::SizeType fieldSize;
typename ITKFieldType::DirectionType fieldDirection;

for( unsigned int d = 0; d < Dimension; d++ )
{
fieldOrigin[d] = originP(d);
fieldSpacing[d] = spacingP(d);
fieldSize[d] = sizeP(d);
for( unsigned int e = 0; e < Dimension; e++ )
{
fieldDirection(d, e) = directionP(d, e);
}
}
field->SetRegions( fieldSize );
field->SetOrigin( fieldOrigin );
field->SetSpacing( fieldSpacing );
field->SetDirection( fieldDirection );
field->Allocate();
}

auto sourceLandmarks = PointSetType::New();
auto targetLandmarks = PointSetType::New();
typename PointSetType::PointsContainer::Pointer sourceLandmarkContainer = sourceLandmarks->GetPoints();
typename PointSetType::PointsContainer::Pointer targetLandmarkContainer = targetLandmarks->GetPoints();

PointType sourcePoint;
PointType targetPoint;

auto displacementOriginsP = displacementOrigins.unchecked<2>();
auto displacementsP = displacements.unchecked<2>();
unsigned int numberOfPoints = displacementsP.shape(0);

for( unsigned int n = 0; n < numberOfPoints; n++ )
{
for( unsigned int d = 0; d < Dimension; d++ )
{
sourcePoint[d] = displacementOriginsP(n, d);
targetPoint[d] = displacementOriginsP(n, d) + displacementsP(n, d);
}
sourceLandmarkContainer->InsertElement( n, sourcePoint );
targetLandmarkContainer->InsertElement( n, targetPoint );
}

tps->SetSourceLandmarks( sourceLandmarks );
tps->SetTargetLandmarks( targetLandmarks );
tps->ComputeWMatrix();

//////////////////////////
//
// Now convert back to vector image type.
//

ANTsFieldPointerType antsField = ANTsFieldType::New();
antsField->CopyInformation( field );
antsField->SetRegions( field->GetRequestedRegion() );
antsField->SetVectorLength( Dimension );
antsField->Allocate();

typename TransformType::InputPointType source;
typename TransformType::OutputPointType target;

IteratorType It( field, field->GetLargestPossibleRegion() );
for( It.GoToBegin(); !It.IsAtEnd(); ++It )
{
field->TransformIndexToPhysicalPoint( It.GetIndex(), source );
target = tps->TransformPoint( source );

typename ANTsFieldType::PixelType antsVector( Dimension );
for( unsigned int d = 0; d < Dimension; d++ )
{
antsVector[d] = target[d] - source[d];
}
antsField->SetPixel( It.GetIndex(), antsVector );
}

return wrap< ANTsFieldType >( antsField );
}

PYBIND11_MODULE(fitThinPlateSplineDisplacementFieldToScatteredData, m)
{
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD2", &fitThinPlateSplineVectorImageToScatteredDataHelper<2>);
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD3", &fitThinPlateSplineVectorImageToScatteredDataHelper<3>);
}

1 change: 1 addition & 0 deletions ants/lib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .fitBsplineObjectToScatteredData import *
from .fitBsplineDisplacementField import *
from .fitBsplineDisplacementFieldToScatteredData import *
from .fitThinPlateSplineDisplacementFieldToScatteredData import *
from .fsl2antstransform import *
from .getNeighborhoodMatrix import *
from .hausdorffDistance import *
Expand Down
21 changes: 18 additions & 3 deletions ants/registration/landmark_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ..core import ants_image_io as iio2
from ..utils import fit_bspline_displacement_field
from ..utils import fit_bspline_object_to_scattered_data
from ..utils import fit_thin_plate_spline_displacement_field
from ..utils import integrate_velocity_field
from ..utils import smooth_image
from ..utils import compose_displacement_fields
Expand Down Expand Up @@ -62,7 +63,7 @@ def fit_transform_to_paired_points(moving_points,
of points and d is the dimensionality.
transform_type : character
'rigid', 'similarity', "affine', 'bspline', 'diffeo', 'syn', or 'time-varying (tv)'.
'rigid', 'similarity', "affine', 'bspline', 'tps', 'diffeo', 'syn', or 'time-varying (tv)'.
regularization : scalar
Ridge penalty in [0,1] for linear transforms.
Expand Down Expand Up @@ -159,13 +160,13 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):
has_components=True)
return(field)

allowed_transforms = ['rigid', 'affine', 'similarity', 'bspline', 'diffeo', 'syn', 'tv', 'time-varying']
allowed_transforms = ['rigid', 'affine', 'similarity', 'bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']
if not transform_type.lower() in allowed_transforms:
raise ValueError(transform_type + " transform not supported.")

transform_type = transform_type.lower()

if domain_image is None and transform_type in ['bspline', 'diffeo', 'syn', 'tv', 'time-varying']:
if domain_image is None and transform_type in ['bspline', 'tps', 'diffeo', 'syn', 'tv', 'time-varying']:
raise ValueError("Domain image needs to be specified.")

if not fixed_points.shape == moving_points.shape:
Expand Down Expand Up @@ -239,6 +240,20 @@ def create_zero_velocity_field(domain_image, number_of_time_points=2):

return xfrm

elif transform_type == "tps":

tps_displacement_field = fit_thin_plate_spline_displacement_field(
displacement_origins=fixed_points,
displacements=moving_points - fixed_points,
origin=domain_image.origin,
spacing=domain_image.spacing,
size=domain_image.shape,
direction=domain_image.direction)

xfrm = txio.transform_from_displacement_field(tps_displacement_field)

return xfrm

elif transform_type == "diffeo":

if verbose:
Expand Down
1 change: 1 addition & 0 deletions ants/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .denoise_image import *
from .fit_bspline_object_to_scattered_data import *
from .fit_bspline_displacement_field import *
from .fit_thin_plate_spline_displacement_field import *
from .get_ants_data import *
from .get_centroids import *
from .get_mask import *
Expand Down
111 changes: 111 additions & 0 deletions ants/utils/fit_thin_plate_spline_displacement_field.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
__all__ = ["fit_thin_plate_spline_displacement_field"]

import numpy as np

from ..core import ants_image as iio
from .. import core
from .. import utils


def fit_thin_plate_spline_displacement_field(displacement_origins=None,
displacements=None,
origin=None,
spacing=None,
size=None,
direction=None):

"""
Fit a thin-plate spline object to a a set of points with associated displacements.
This is basically a wrapper for the ITK filter
https://itk.org/Doxygen/html/itkThinPlateSplineKernelTransform_8h.html
ANTsR function: `fitThinPlateSplineToDisplacementField`
Arguments
---------
displacement_origins : 2-D numpy array
Matrix (number_of_points x dimension) defining the origins of the input
displacement points. Default = None.
displacements : 2-D numpy array
Matrix (number_of_points x dimension) defining the displacements of the input
displacement points. Default = None.
origin : n-D tuple
Defines the physical origin of the B-spline object.
spacing : n-D tuple
Defines the physical spacing of the B-spline object.
size : n-D tuple
Defines the size (length) of the spline object. Note that the length of the
spline object in dimension d is defined as spacing[d] * size[d]-1.
direction : 2-D numpy array
Booleans defining whether or not the corresponding parametric dimension is
closed (e.g., closed loop). Default = None.
Returns
-------
Returns an ANTsImage.
Example
-------
>>> # Perform 2-D fitting
>>>
>>> import ants, numpy
>>>
>>> points = numpy.array([[-50, -50]])
>>> deltas = numpy.array([[10, 10]])
>>>
>>> tps_field = ants.fit_thin_plate_spline_displacement_field(
>>> displacement_origins=points, displacements=deltas,
>>> origin=[0.0, 0.0], spacing=[1.0, 1.0], size=[100, 100],
>>> direction=numpy.array([[-1, 0], [0, -1]]))
"""

dimensionality = displacement_origins.shape[1]
if displacements.shape[1] != dimensionality:
raise ValueError("Dimensionality between origins and displacements does not match.")

if displacement_origins is None or displacement_origins is None:
raise ValueError("Missing input. Input point set (origins + displacements) needs to be specified." )

if origin is not None and len(origin) != dimensionality:
raise ValueError("Origin is not of length dimensionality.")

if spacing is not None and len(spacing) != dimensionality:
raise ValueError("Spacing is not of length dimensionality.")

if size is not None and len(size) != dimensionality:
raise ValueError("Size is not of length dimensionality.")

if direction is not None and (direction.shape[0] != dimensionality and direction.shape[1] != dimensionality):
raise ValueError("Direction is not of shape dimensionality x dimensionality.")

# It would seem that pybind11 doesn't really play nicely when the
# arguments are 'None'

if origin is None:
origin = np.empty(0)

if spacing is None:
spacing = np.empty(0)

if size is None:
size = np.empty(0)

if direction is None:
direction = np.empty((0, 0))

tps_field = None
libfn = utils.get_lib_fn("fitThinPlateSplineDisplacementFieldToScatteredDataD%i" % (dimensionality))
tps_field = libfn(displacement_origins, displacements, origin, spacing, size, direction)

tps_displacement_field = iio.ANTsImage(pixeltype='float',
dimension=dimensionality, components=dimensionality,
pointer=tps_field).clone('float')
return tps_displacement_field

0 comments on commit 8d15800

Please sign in to comment.