From 8c10b7dd327cebb7260e0400d21f4acd3fe10c50 Mon Sep 17 00:00:00 2001 From: Nick Tustison Date: Fri, 30 Dec 2022 16:25:01 -0800 Subject: [PATCH] ENH: average possible overlapping points. --- ...splineDisplacementFieldToScatteredData.cxx | 41 ++++++++++++------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/ants/lib/LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx b/ants/lib/LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx index 7753f58c..443ece4a 100644 --- a/ants/lib/LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx +++ b/ants/lib/LOCAL_fitBsplineDisplacementFieldToScatteredData.cxx @@ -110,12 +110,22 @@ py::capsule fitBsplineVectorImageToScatteredDataHelper( weightImage->Allocate(); weightImage->FillBuffer( 0.0 ); + WeightImagePointerType countImage = WeightImageType::New(); + countImage->SetOrigin( bsplineFilter->GetBSplineDomainOrigin() ); + countImage->SetSpacing( bsplineFilter->GetBSplineDomainSpacing() ); + countImage->SetDirection( bsplineFilter->GetBSplineDomainDirection() ); + countImage->SetRegions( bsplineFilter->GetBSplineDomainSize() ); + countImage->Allocate(); + countImage->FillBuffer( 0.0 ); + + VectorType zeroVector( 0.0 ); ITKFieldPointerType rasterizedField = ITKFieldType::New(); rasterizedField->SetOrigin( bsplineFilter->GetBSplineDomainOrigin() ); rasterizedField->SetSpacing( bsplineFilter->GetBSplineDomainSpacing() ); rasterizedField->SetDirection( bsplineFilter->GetBSplineDomainDirection() ); rasterizedField->SetRegions( bsplineFilter->GetBSplineDomainSize() ); rasterizedField->Allocate(); + rasterizedField->FillBuffer( zeroVector ); for( unsigned int n = 0; n < numberOfPoints; n++ ) { @@ -126,34 +136,37 @@ py::capsule fitBsplineVectorImageToScatteredDataHelper( imagePoint[d] = displacementOriginsP(n, d); imageDisplacement[d] = displacementsP(n, d); } - typename ITKFieldType::IndexType imageIndex = + typename ITKFieldType::IndexType imageIndex = weightImage->TransformPhysicalPointToIndex( imagePoint ); - weightImage->SetPixel( imageIndex, displacementWeights[n] ); - rasterizedField->SetPixel( imageIndex, imageDisplacement ); + weightImage->SetPixel( imageIndex, displacementWeights[n] + weightImage->GetPixel( imageIndex ) ); + rasterizedField->SetPixel( imageIndex, imageDisplacement + rasterizedField->GetPixel( imageIndex ) ); + countImage->SetPixel( imageIndex, 1.0 + countImage->GetPixel( imageIndex ) ); } // Second, iterate through the weight image and pull those indices/points which have non-zero weights. - unsigned count = 0; + unsigned count = 0; - typename itk::ImageRegionIteratorWithIndex - ItW( weightImage, weightImage->GetLargestPossibleRegion() ); - for( ItW.GoToBegin(); ! ItW.IsAtEnd(); ++ItW ) + typename itk::ImageRegionIteratorWithIndex + ItC( countImage, countImage->GetLargestPossibleRegion() ); + for( ItC.GoToBegin(); ! ItC.IsAtEnd(); ++ItC ) { - if( ItW.Get() > 0.0 ) + if( ItC.Get() > 0.0 ) { typename ITKFieldType::PointType imagePoint; - weightImage->TransformIndexToPhysicalPoint( ItW.GetIndex(), imagePoint ); + weightImage->TransformIndexToPhysicalPoint( ItC.GetIndex(), imagePoint ); typename PointSetType::PointType point; point.CastFrom( imagePoint ); pointSet->SetPoint( count, point ); - pointSet->SetPointData( count, rasterizedField->GetPixel( ItW.GetIndex() ) ); - weights->InsertElement( count, ItW.Get() ); + VectorType imageDisplacement = rasterizedField->GetPixel( ItC.GetIndex() ) / ItC.Get(); + RealType weight = weightImage->GetPixel( ItC.GetIndex() ) / ItC.Get(); + pointSet->SetPointData( count, imageDisplacement ); + weights->InsertElement( count, weight ); count++; } - } - } - else + } + } + else { for( unsigned int n = 0; n < numberOfPoints; n++ ) {