-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #460 from ANTsX/TPS
ENH: Add thin-plate splines.
- Loading branch information
Showing
6 changed files
with
282 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
149 changes: 149 additions & 0 deletions
149
ants/lib/LOCAL_fitThinPlateSplineDisplacementFieldToScatteredData.cxx
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/numpy.h> | ||
|
||
#include <exception> | ||
#include <vector> | ||
#include <string> | ||
|
||
#include "itkImage.h" | ||
#include "itkPointSet.h" | ||
#include "itkThinPlateSplineKernelTransform.h" | ||
#include "itkImageRegionIteratorWithIndex.h" | ||
|
||
#include "LOCAL_antsImage.h" | ||
|
||
namespace py = pybind11; | ||
|
||
template<unsigned int Dimension> | ||
py::capsule fitThinPlateSplineVectorImageToScatteredDataHelper( | ||
py::array_t<double> displacementOrigins, | ||
py::array_t<double> displacements, | ||
py::array_t<double> origin, | ||
py::array_t<double> spacing, | ||
py::array_t<unsigned int> size, | ||
py::array_t<double> direction | ||
) | ||
{ | ||
using RealType = float; | ||
|
||
using ANTsFieldType = itk::VectorImage<RealType, Dimension>; | ||
using ANTsFieldPointerType = typename ANTsFieldType::Pointer; | ||
|
||
using VectorType = itk::Vector<RealType, Dimension>; | ||
|
||
using ITKFieldType = itk::Image<VectorType, Dimension>; | ||
using IteratorType = itk::ImageRegionIteratorWithIndex<ITKFieldType>; | ||
|
||
using CoordinateRepType = float; | ||
using TransformType = itk::ThinPlateSplineKernelTransform<CoordinateRepType, Dimension>; | ||
using PointType = itk::Point<CoordinateRepType, Dimension>; | ||
using PointSetType = typename TransformType::PointSetType; | ||
|
||
auto tps = TransformType::New(); | ||
|
||
//////////////////////////// | ||
// | ||
// Define the output thin-plate spline field domain | ||
// | ||
|
||
auto field = ITKFieldType::New(); | ||
|
||
auto originP = origin.unchecked<1>(); | ||
auto spacingP = spacing.unchecked<1>(); | ||
auto sizeP = size.unchecked<1>(); | ||
auto directionP = direction.unchecked<2>(); | ||
|
||
if( originP.shape(0) == 0 || sizeP.shape(0) == 0 || spacingP.shape(0) == 0 || directionP.shape(0) == 0 ) | ||
{ | ||
throw std::invalid_argument( "Thin-plate spline domain is not specified." ); | ||
} | ||
else | ||
{ | ||
typename ITKFieldType::PointType fieldOrigin; | ||
typename ITKFieldType::SpacingType fieldSpacing; | ||
typename ITKFieldType::SizeType fieldSize; | ||
typename ITKFieldType::DirectionType fieldDirection; | ||
|
||
for( unsigned int d = 0; d < Dimension; d++ ) | ||
{ | ||
fieldOrigin[d] = originP(d); | ||
fieldSpacing[d] = spacingP(d); | ||
fieldSize[d] = sizeP(d); | ||
for( unsigned int e = 0; e < Dimension; e++ ) | ||
{ | ||
fieldDirection(d, e) = directionP(d, e); | ||
} | ||
} | ||
field->SetRegions( fieldSize ); | ||
field->SetOrigin( fieldOrigin ); | ||
field->SetSpacing( fieldSpacing ); | ||
field->SetDirection( fieldDirection ); | ||
field->Allocate(); | ||
} | ||
|
||
auto sourceLandmarks = PointSetType::New(); | ||
auto targetLandmarks = PointSetType::New(); | ||
typename PointSetType::PointsContainer::Pointer sourceLandmarkContainer = sourceLandmarks->GetPoints(); | ||
typename PointSetType::PointsContainer::Pointer targetLandmarkContainer = targetLandmarks->GetPoints(); | ||
|
||
PointType sourcePoint; | ||
PointType targetPoint; | ||
|
||
auto displacementOriginsP = displacementOrigins.unchecked<2>(); | ||
auto displacementsP = displacements.unchecked<2>(); | ||
unsigned int numberOfPoints = displacementsP.shape(0); | ||
|
||
for( unsigned int n = 0; n < numberOfPoints; n++ ) | ||
{ | ||
for( unsigned int d = 0; d < Dimension; d++ ) | ||
{ | ||
sourcePoint[d] = displacementOriginsP(n, d); | ||
targetPoint[d] = displacementOriginsP(n, d) + displacementsP(n, d); | ||
} | ||
sourceLandmarkContainer->InsertElement( n, sourcePoint ); | ||
targetLandmarkContainer->InsertElement( n, targetPoint ); | ||
} | ||
|
||
tps->SetSourceLandmarks( sourceLandmarks ); | ||
tps->SetTargetLandmarks( targetLandmarks ); | ||
tps->ComputeWMatrix(); | ||
|
||
////////////////////////// | ||
// | ||
// Now convert back to vector image type. | ||
// | ||
|
||
ANTsFieldPointerType antsField = ANTsFieldType::New(); | ||
antsField->CopyInformation( field ); | ||
antsField->SetRegions( field->GetRequestedRegion() ); | ||
antsField->SetVectorLength( Dimension ); | ||
antsField->Allocate(); | ||
|
||
typename TransformType::InputPointType source; | ||
typename TransformType::OutputPointType target; | ||
|
||
IteratorType It( field, field->GetLargestPossibleRegion() ); | ||
for( It.GoToBegin(); !It.IsAtEnd(); ++It ) | ||
{ | ||
field->TransformIndexToPhysicalPoint( It.GetIndex(), source ); | ||
target = tps->TransformPoint( source ); | ||
|
||
typename ANTsFieldType::PixelType antsVector( Dimension ); | ||
for( unsigned int d = 0; d < Dimension; d++ ) | ||
{ | ||
antsVector[d] = target[d] - source[d]; | ||
} | ||
antsField->SetPixel( It.GetIndex(), antsVector ); | ||
} | ||
|
||
return wrap< ANTsFieldType >( antsField ); | ||
} | ||
|
||
PYBIND11_MODULE(fitThinPlateSplineDisplacementFieldToScatteredData, m) | ||
{ | ||
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD2", &fitThinPlateSplineVectorImageToScatteredDataHelper<2>); | ||
m.def("fitThinPlateSplineDisplacementFieldToScatteredDataD3", &fitThinPlateSplineVectorImageToScatteredDataHelper<3>); | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
__all__ = ["fit_thin_plate_spline_displacement_field"] | ||
|
||
import numpy as np | ||
|
||
from ..core import ants_image as iio | ||
from .. import core | ||
from .. import utils | ||
|
||
|
||
def fit_thin_plate_spline_displacement_field(displacement_origins=None, | ||
displacements=None, | ||
origin=None, | ||
spacing=None, | ||
size=None, | ||
direction=None): | ||
|
||
""" | ||
Fit a thin-plate spline object to a a set of points with associated displacements. | ||
This is basically a wrapper for the ITK filter | ||
https://itk.org/Doxygen/html/itkThinPlateSplineKernelTransform_8h.html | ||
ANTsR function: `fitThinPlateSplineToDisplacementField` | ||
Arguments | ||
--------- | ||
displacement_origins : 2-D numpy array | ||
Matrix (number_of_points x dimension) defining the origins of the input | ||
displacement points. Default = None. | ||
displacements : 2-D numpy array | ||
Matrix (number_of_points x dimension) defining the displacements of the input | ||
displacement points. Default = None. | ||
origin : n-D tuple | ||
Defines the physical origin of the B-spline object. | ||
spacing : n-D tuple | ||
Defines the physical spacing of the B-spline object. | ||
size : n-D tuple | ||
Defines the size (length) of the spline object. Note that the length of the | ||
spline object in dimension d is defined as spacing[d] * size[d]-1. | ||
direction : 2-D numpy array | ||
Booleans defining whether or not the corresponding parametric dimension is | ||
closed (e.g., closed loop). Default = None. | ||
Returns | ||
------- | ||
Returns an ANTsImage. | ||
Example | ||
------- | ||
>>> # Perform 2-D fitting | ||
>>> | ||
>>> import ants, numpy | ||
>>> | ||
>>> points = numpy.array([[-50, -50]]) | ||
>>> deltas = numpy.array([[10, 10]]) | ||
>>> | ||
>>> tps_field = ants.fit_thin_plate_spline_displacement_field( | ||
>>> displacement_origins=points, displacements=deltas, | ||
>>> origin=[0.0, 0.0], spacing=[1.0, 1.0], size=[100, 100], | ||
>>> direction=numpy.array([[-1, 0], [0, -1]])) | ||
""" | ||
|
||
dimensionality = displacement_origins.shape[1] | ||
if displacements.shape[1] != dimensionality: | ||
raise ValueError("Dimensionality between origins and displacements does not match.") | ||
|
||
if displacement_origins is None or displacement_origins is None: | ||
raise ValueError("Missing input. Input point set (origins + displacements) needs to be specified." ) | ||
|
||
if origin is not None and len(origin) != dimensionality: | ||
raise ValueError("Origin is not of length dimensionality.") | ||
|
||
if spacing is not None and len(spacing) != dimensionality: | ||
raise ValueError("Spacing is not of length dimensionality.") | ||
|
||
if size is not None and len(size) != dimensionality: | ||
raise ValueError("Size is not of length dimensionality.") | ||
|
||
if direction is not None and (direction.shape[0] != dimensionality and direction.shape[1] != dimensionality): | ||
raise ValueError("Direction is not of shape dimensionality x dimensionality.") | ||
|
||
# It would seem that pybind11 doesn't really play nicely when the | ||
# arguments are 'None' | ||
|
||
if origin is None: | ||
origin = np.empty(0) | ||
|
||
if spacing is None: | ||
spacing = np.empty(0) | ||
|
||
if size is None: | ||
size = np.empty(0) | ||
|
||
if direction is None: | ||
direction = np.empty((0, 0)) | ||
|
||
tps_field = None | ||
libfn = utils.get_lib_fn("fitThinPlateSplineDisplacementFieldToScatteredDataD%i" % (dimensionality)) | ||
tps_field = libfn(displacement_origins, displacements, origin, spacing, size, direction) | ||
|
||
tps_displacement_field = iio.ANTsImage(pixeltype='float', | ||
dimension=dimensionality, components=dimensionality, | ||
pointer=tps_field).clone('float') | ||
return tps_displacement_field | ||
|