From f17a8ad25a1271e50b48e686ed094ab08490b737 Mon Sep 17 00:00:00 2001 From: Simon Rit Date: Tue, 14 Nov 2023 12:05:23 +0100 Subject: [PATCH] ENH: Made weights of weighted least squares optional in conjugate gradient --- .../rtkconjugategradient.cxx | 24 ++++--------- ...gateGradientConeBeamReconstructionFilter.h | 5 +++ ...teGradientConeBeamReconstructionFilter.hxx | 34 ++++++++++++------- ...econstructionConjugateGradientOperator.hxx | 6 ++-- ...rtkconjugategradientreconstructiontest.cxx | 2 +- ...kcylindricaldetectorreconstructiontest.cxx | 2 +- 6 files changed, 39 insertions(+), 34 deletions(-) diff --git a/applications/rtkconjugategradient/rtkconjugategradient.cxx b/applications/rtkconjugategradient/rtkconjugategradient.cxx index 4ad79b186..2e8ff427b 100644 --- a/applications/rtkconjugategradient/rtkconjugategradient.cxx +++ b/applications/rtkconjugategradient/rtkconjugategradient.cxx @@ -77,25 +77,15 @@ main(int argc, char * argv[]) inputFilter = constantImageSource; } - // Read weights if given, otherwise default to weights all equal to one - itk::ImageSource::Pointer weightsSource; + // Read weights if given + OutputImageType::Pointer inputWeights; if (args_info.weights_given) { using WeightsReaderType = itk::ImageFileReader; WeightsReaderType::Pointer weightsReader = WeightsReaderType::New(); weightsReader->SetFileName(args_info.weights_arg); - weightsSource = weightsReader; - } - else - { - using ConstantWeightsSourceType = rtk::ConstantImageSource; - ConstantWeightsSourceType::Pointer constantWeightsSource = ConstantWeightsSourceType::New(); - - // Set the weights to be like the projections - TRY_AND_EXIT_ON_ITK_EXCEPTION(reader->UpdateOutputInformation()) - constantWeightsSource->SetInformationFromImage(reader->GetOutput()); - constantWeightsSource->SetConstant(1.0); - weightsSource = constantWeightsSource; + inputWeights = weightsReader->GetOutput(); + inputWeights->Update(); } // Read Support Mask if given @@ -110,9 +100,9 @@ main(int argc, char * argv[]) ConjugateGradientFilterType::Pointer conjugategradient = ConjugateGradientFilterType::New(); SetForwardProjectionFromGgo(args_info, conjugategradient.GetPointer()); SetBackProjectionFromGgo(args_info, conjugategradient.GetPointer()); - conjugategradient->SetInput(inputFilter->GetOutput()); - conjugategradient->SetInput(1, reader->GetOutput()); - conjugategradient->SetInput(2, weightsSource->GetOutput()); + conjugategradient->SetInputVolume(inputFilter->GetOutput()); + conjugategradient->SetInputProjectionStack(reader->GetOutput()); + conjugategradient->SetInputWeights(inputWeights); conjugategradient->SetCudaConjugateGradient(!args_info.nocudacg_flag); if (args_info.mask_given) { diff --git a/include/rtkConjugateGradientConeBeamReconstructionFilter.h b/include/rtkConjugateGradientConeBeamReconstructionFilter.h index 2b1fdc188..e0c09a79e 100644 --- a/include/rtkConjugateGradientConeBeamReconstructionFilter.h +++ b/include/rtkConjugateGradientConeBeamReconstructionFilter.h @@ -166,9 +166,14 @@ class ITK_TEMPLATE_EXPORT ConjugateGradientConeBeamReconstructionFilter std::is_same::value, CudaConstantVolumeSource, ConstantImageSource>::type ConstantImageSourceType; + typedef typename std::conditional::value && + std::is_same::value, + CudaConstantVolumeSource, + ConstantImageSource>::type ConstantWeightSourceType; #else using DisplacedDetectorFilterType = DisplacedDetectorImageFilter; using ConstantImageSourceType = ConstantImageSource; + using ConstantWeightSourceType = ConstantImageSource; #endif /** Set the support mask, if any, for support constraint in reconstruction */ diff --git a/include/rtkConjugateGradientConeBeamReconstructionFilter.hxx b/include/rtkConjugateGradientConeBeamReconstructionFilter.hxx index 246c58a9a..acd1c882e 100644 --- a/include/rtkConjugateGradientConeBeamReconstructionFilter.hxx +++ b/include/rtkConjugateGradientConeBeamReconstructionFilter.hxx @@ -20,6 +20,7 @@ #define rtkConjugateGradientConeBeamReconstructionFilter_hxx #include +#include namespace rtk { @@ -29,7 +30,7 @@ ConjugateGradientConeBeamReconstructionFilterSetNumberOfRequiredInputs(3); + this->SetNumberOfRequiredInputs(2); // Set the default values of member parameters m_NumberOfIterations = 3; @@ -74,7 +75,7 @@ void ConjugateGradientConeBeamReconstructionFilter::SetInputWeights( const TWeightsImage * weights) { - this->SetNthInput(2, const_cast(weights)); + this->SetInput("InputWeights", const_cast(weights)); } template @@ -104,7 +105,7 @@ template ::GetInputWeights() { - return static_cast(this->itk::ProcessObject::GetInput(2)); + return static_cast(this->itk::ProcessObject::GetInput("InputWeights")); } template @@ -142,11 +143,14 @@ ConjugateGradientConeBeamReconstructionFilterSetRequestedRegion(inputPtr1->GetLargestPossibleRegion()); - // Input 2 is the weights map on projections, either user-defined or filled with ones (default) - typename TWeightsImage::Pointer inputPtr2 = const_cast(this->GetInputWeights().GetPointer()); - if (!inputPtr2) - return; - inputPtr2->SetRequestedRegion(inputPtr2->GetLargestPossibleRegion()); + // Input "InputWeights" is the weights map on projections, either user-defined or filled with ones (default) + if (this->GetInputWeights().IsNotNull()) + { + typename TWeightsImage::Pointer inputWeights = const_cast(this->GetInputWeights().GetPointer()); + if (!inputWeights) + return; + inputWeights->SetRequestedRegion(inputWeights->GetLargestPossibleRegion()); + } // Input "SupportMask" is the support constraint mask on volume, if any if (this->GetSupportMask().IsNotNull()) @@ -190,6 +194,16 @@ ConjugateGradientConeBeamReconstructionFilterSetSupportMask(this->GetSupportMask()); m_ConjugateGradientFilter->SetX(this->GetInputVolume()); m_DisplacedDetectorFilter->SetDisable(m_DisableDisplacedDetectorFilter); + if (this->GetInputWeights().IsNull()) + { + using PixelType = typename TWeightsImage::PixelType; + using ComponentType = typename itk::PixelTraits::ValueType; + typename ConstantWeightSourceType::Pointer ones = ConstantWeightSourceType::New(); + ones->SetInformationFromImage(this->GetInputProjectionStack()); + ones->SetConstant(PixelType(itk::NumericTraits::One)); + ones->Update(); + this->SetInputWeights(ones->GetOutput()); + } m_DisplacedDetectorFilter->SetInput(this->GetInputWeights()); // Links with the m_BackProjectionFilter should be set here and not @@ -256,10 +270,6 @@ ConjugateGradientConeBeamReconstructionFilterGetSupportMask()) { m_MultiplyOutputFilter->Update(); - } - - if (this->GetSupportMask()) - { this->GraftOutput(m_MultiplyOutputFilter->GetOutput()); } else diff --git a/include/rtkReconstructionConjugateGradientOperator.hxx b/include/rtkReconstructionConjugateGradientOperator.hxx index f9694c7e9..20b2eb44e 100644 --- a/include/rtkReconstructionConjugateGradientOperator.hxx +++ b/include/rtkReconstructionConjugateGradientOperator.hxx @@ -161,10 +161,10 @@ ReconstructionConjugateGradientOperatorSetRequestedRegion(inputPtr1->GetLargestPossibleRegion()); // Input 2 is the weights map on projections, if any - typename TWeightsImage::Pointer inputPtr2 = const_cast(this->GetInputWeights().GetPointer()); - if (!inputPtr2) + typename TWeightsImage::Pointer inputWeights = const_cast(this->GetInputWeights().GetPointer()); + if (!inputWeights) return; - inputPtr2->SetRequestedRegion(inputPtr2->GetLargestPossibleRegion()); + inputWeights->SetRequestedRegion(inputWeights->GetLargestPossibleRegion()); // Input "SupportMask" is the support constraint mask on volume, if any if (this->GetSupportMask().IsNotNull()) diff --git a/test/rtkconjugategradientreconstructiontest.cxx b/test/rtkconjugategradientreconstructiontest.cxx index 507127d59..677b67b34 100644 --- a/test/rtkconjugategradientreconstructiontest.cxx +++ b/test/rtkconjugategradientreconstructiontest.cxx @@ -142,7 +142,7 @@ main(int, char **) std::cout << "\n\n****** Case 1: Voxel-Based Backprojector ******" << std::endl; conjugategradient->SetBackProjectionFilter(ConjugateGradientType::BP_VOXELBASED); - conjugategradient->SetInput(2, uniformWeightsSource->GetOutput()); + conjugategradient->SetInputWeights(uniformWeightsSource->GetOutput()); TRY_AND_EXIT_ON_ITK_EXCEPTION(conjugategradient->Update()); CheckImageQuality(conjugategradient->GetOutput(), dsl->GetOutput(), 0.08, 23, 2.0); diff --git a/test/rtkcylindricaldetectorreconstructiontest.cxx b/test/rtkcylindricaldetectorreconstructiontest.cxx index 277e915c2..6cba9c901 100644 --- a/test/rtkcylindricaldetectorreconstructiontest.cxx +++ b/test/rtkcylindricaldetectorreconstructiontest.cxx @@ -136,7 +136,7 @@ main(int, char **) ConjugateGradientType::Pointer conjugategradient = ConjugateGradientType::New(); conjugategradient->SetInput(tomographySource->GetOutput()); conjugategradient->SetInput(1, rei->GetOutput()); - conjugategradient->SetInput(2, uniformWeightsSource->GetOutput()); + conjugategradient->SetInputWeights(uniformWeightsSource->GetOutput()); conjugategradient->SetGeometry(geometry); conjugategradient->SetNumberOfIterations(5); conjugategradient->SetDisableDisplacedDetectorFilter(true);