diff --git a/source/tdis/CMakeLists.txt b/source/tdis/CMakeLists.txt index dd8145a..2e786bf 100644 --- a/source/tdis/CMakeLists.txt +++ b/source/tdis/CMakeLists.txt @@ -53,6 +53,11 @@ add_executable(tdis tracking/KalmanFittingFactory.cpp # tracking/CKFTracking.h # tracking/CKFTracking.cc + tracking/TrackFitterFunction.hpp + tracking/KalmanFitterFunction.cpp + tracking/RefittingCalibrator.hpp + tracking/RefittingCalibrator.cpp + ) # ---------- FIND REQUIRED PACKAGES ------------- @@ -66,6 +71,7 @@ find_package(Acts REQUIRED COMPONENTS Core PluginTGeo PluginJson) target_include_directories(tdis PUBLIC ${CMAKE_CURRENT_LIST_DIR} "${CMAKE_CURRENT_LIST_DIR}/..") target_link_libraries(tdis ${JANA_LIB} + libc++ ROOT::RIO ROOT::Core podio::podio podio::podioRootIO podio_model_lib podio_model_dict spdlog::spdlog diff --git a/source/tdis/tracking/DD4hepBField.cc b/source/tdis/tracking/DD4hepBField.cc deleted file mode 100644 index e60b62d..0000000 --- a/source/tdis/tracking/DD4hepBField.cc +++ /dev/null @@ -1,46 +0,0 @@ -// SPDX-License-Identifier: LGPL-3.0-or-later -// Copyright (C) 2022 Whitney Armstrong, Wouter Deconinck - -#include "DD4hepBField.h" - -#include -#include -#include -#include -#include -#include - -namespace eicrecon::BField { - - Acts::Result DD4hepBField::getField(const Acts::Vector3& position, - Acts::MagneticFieldProvider::Cache& /*cache*/) const - { - dd4hep::Position pos( - position[0] * (dd4hep::mm / Acts::UnitConstants::mm), - position[1] * (dd4hep::mm / Acts::UnitConstants::mm), - position[2] * (dd4hep::mm / Acts::UnitConstants::mm)); - - auto fieldObj = m_det->field(); - auto field = fieldObj.magneticField(pos) * (Acts::UnitConstants::T / dd4hep::tesla); - - // FIXME Acts doesn't seem to like exact zero components - if (field.x() * field.y() * field.z() == 0) { - static dd4hep::Direction epsilon{ - std::numeric_limits::epsilon(), - std::numeric_limits::epsilon(), - std::numeric_limits::epsilon() - }; - field += epsilon; - } - - return Acts::Result::success({field.x(), field.y(), field.z()}); - } - - Acts::Result DD4hepBField::getFieldGradient(const Acts::Vector3& position, - Acts::ActsMatrix<3, 3>& /*derivative*/, - Acts::MagneticFieldProvider::Cache& cache) const - { - return this->getField(position, cache); - } - -} // namespace eicrecon::BField diff --git a/source/tdis/tracking/DD4hepBField.h b/source/tdis/tracking/DD4hepBField.h deleted file mode 100644 index 979fd83..0000000 --- a/source/tdis/tracking/DD4hepBField.h +++ /dev/null @@ -1,90 +0,0 @@ -// This file is part of the Acts project. -// -// Copyright (C) 2016-2018 CERN for the benefit of the Acts project -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this -// file, You can obtain one at http://mozilla.org/MPL/2.0/. - -#pragma once - -#include -#include -#include -#include - -#include -#include -#include - - - - -namespace eicrecon::BField { - - ///// The Context to be handed around - //struct ScalableBFieldContext { - // double scalor = 1.; - //}; - - /** Use the dd4hep magnetic field in acts. - * - * \ingroup magnets - * \ingroup magsvc - */ - class DD4hepBField final : public Acts::MagneticFieldProvider { - public: - - - public: - struct Cache { - Cache(const Acts::MagneticFieldContext& /*mcfg*/) { } - }; - - Acts::MagneticFieldProvider::Cache makeCache(const Acts::MagneticFieldContext& mctx) const override - { - - return Acts::MagneticFieldProvider::Cache(std::in_place_type, mctx); - - } - - /** construct constant magnetic field from field vector. - * - */ - explicit DD4hepBField() { - - } - - /** retrieve magnetic field value. - * - * @param [in] position global position - * @param [in] cache Cache object (is ignored) - * @return magnetic field vector - * - * @note The @p position is ignored and only kept as argument to provide - * a consistent interface with other magnetic field services. - */ - Acts::Result getField(const Acts::Vector3& position, Acts::MagneticFieldProvider::Cache& cache) const override; - - /** @brief retrieve magnetic field value & its gradient - * - * @param [in] position global position - * @param [out] derivative gradient of magnetic field vector as (3x3) - * matrix - * @param [in] cache Cache object (is ignored) - * @return magnetic field vector - * - * @note The @p position is ignored and only kept as argument to provide - * a consistent interface with other magnetic field services. - * @note currently the derivative is not calculated - * @todo return derivative - */ - Acts::Result getFieldGradient(const Acts::Vector3& position, Acts::ActsMatrix<3, 3>& /*derivative*/, - Acts::MagneticFieldProvider::Cache& cache) const override; - }; - - using BFieldVariant = std::variant>; - - - -} // namespace eicrecon::BField diff --git a/source/tdis/tracking/KalmanFitterFunction.cpp b/source/tdis/tracking/KalmanFitterFunction.cpp new file mode 100644 index 0000000..b79dab0 --- /dev/null +++ b/source/tdis/tracking/KalmanFitterFunction.cpp @@ -0,0 +1,194 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "Acts/Definitions/Direction.hpp" +#include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" +#include "Acts/Propagator/Navigator.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Propagator/SympyStepper.hpp" +#include "Acts/TrackFitting/GainMatrixSmoother.hpp" +#include "Acts/TrackFitting/GainMatrixUpdater.hpp" +#include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include + +namespace Acts { +class MagneticFieldProvider; +class SourceLink; +class Surface; +class TrackingGeometry; +} // namespace Acts + +namespace { + +using Stepper = Acts::SympyStepper; +using Propagator = Acts::Propagator; +using Fitter = Acts::KalmanFitter; +using DirectPropagator = Acts::Propagator; +using DirectFitter = + Acts::KalmanFitter; + +using TrackContainer = + Acts::TrackContainer; + +struct SimpleReverseFilteringLogic { + double momentumThreshold = 0; + + bool doBackwardFiltering( + Acts::VectorMultiTrajectory::ConstTrackStateProxy trackState) const { + auto momentum = std::abs(1 / trackState.filtered()[Acts::eBoundQOverP]); + return (momentum <= momentumThreshold); + } +}; + +using namespace ActsExamples; + +struct KalmanFitterFunctionImpl final : public TrackFitterFunction { + Fitter fitter; + DirectFitter directFitter; + + Acts::GainMatrixUpdater kfUpdater; + Acts::GainMatrixSmoother kfSmoother; + SimpleReverseFilteringLogic reverseFilteringLogic; + + bool multipleScattering = false; + bool energyLoss = false; + Acts::FreeToBoundCorrection freeToBoundCorrection; + + IndexSourceLink::SurfaceAccessor slSurfaceAccessor; + + KalmanFitterFunctionImpl(Fitter&& f, DirectFitter&& df, + const Acts::TrackingGeometry& trkGeo) + : fitter(std::move(f)), + directFitter(std::move(df)), + slSurfaceAccessor{trkGeo} {} + + template + auto makeKfOptions(const GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::KalmanFitterExtensions extensions; + extensions.updater.connect< + &Acts::GainMatrixUpdater::operator()>( + &kfUpdater); + extensions.smoother.connect< + &Acts::GainMatrixSmoother::operator()>( + &kfSmoother); + extensions.reverseFilteringLogic + .connect<&SimpleReverseFilteringLogic::doBackwardFiltering>( + &reverseFilteringLogic); + + Acts::KalmanFitterOptions kfOptions( + options.geoContext, options.magFieldContext, options.calibrationContext, + extensions, options.propOptions, &(*options.referenceSurface)); + + kfOptions.referenceSurfaceStrategy = + Acts::KalmanFitterTargetSurfaceStrategy::first; + kfOptions.multipleScattering = multipleScattering; + kfOptions.energyLoss = energyLoss; + kfOptions.freeToBoundCorrection = freeToBoundCorrection; + kfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>( + &calibrator); + + if (options.doRefit) { + kfOptions.extensions.surfaceAccessor.connect<&RefittingCalibrator::accessSurface>(); + } else { + kfOptions.extensions.surfaceAccessor.connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &slSurfaceAccessor); + } + + return kfOptions; + } + + TrackFitterResult operator()(const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks) const override { + const auto kfOptions = makeKfOptions(options, calibrator); + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + kfOptions, tracks); + } + + TrackFitterResult operator()( + const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const RefittingCalibrator& calibrator, + const std::vector& surfaceSequence, + TrackContainer& tracks) const override { + const auto kfOptions = makeKfOptions(options, calibrator); + return directFitter.fit(sourceLinks.begin(), sourceLinks.end(), + initialParameters, kfOptions, surfaceSequence, + tracks); + } +}; + +} // namespace + +std::shared_ptr +ActsExamples::makeKalmanFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering, bool energyLoss, + double reverseFilteringMomThreshold, + Acts::FreeToBoundCorrection freeToBoundCorrection, + const Acts::Logger& logger) { + // Stepper should be copied into the fitters + const Stepper stepper(std::move(magneticField)); + + // Standard fitter + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; + cfg.resolvePassive = false; + cfg.resolveMaterial = true; + cfg.resolveSensitive = true; + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(stepper, std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); + + // Direct fitter + Acts::DirectNavigator directNavigator{ + logger.cloneWithSuffix("DirectNavigator")}; + DirectPropagator directPropagator(stepper, std::move(directNavigator), + logger.cloneWithSuffix("DirectPropagator")); + DirectFitter directTrackFitter(std::move(directPropagator), + logger.cloneWithSuffix("DirectFitter")); + + // build the fitter function. owns the fitter object. + auto fitterFunction = std::make_shared( + std::move(trackFitter), std::move(directTrackFitter), geo); + fitterFunction->multipleScattering = multipleScattering; + fitterFunction->energyLoss = energyLoss; + fitterFunction->reverseFilteringLogic.momentumThreshold = + reverseFilteringMomThreshold; + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; + + return fitterFunction; +} diff --git a/source/tdis/tracking/KalmanFittingFactory.cpp b/source/tdis/tracking/KalmanFittingFactory.cpp index 45e05e7..58e779b 100644 --- a/source/tdis/tracking/KalmanFittingFactory.cpp +++ b/source/tdis/tracking/KalmanFittingFactory.cpp @@ -31,6 +31,7 @@ #include #include #include + #include //#include //#include @@ -78,6 +79,8 @@ void tdis::tracking::KalmanFittingFactory::Execute(int32_t runNumber, uint64_t e m_log->debug("{}::Execute", this->GetTypeName()); + std::shared_ptr fit; + } template diff --git a/source/tdis/tracking/RefittingCalibrator.cpp b/source/tdis/tracking/RefittingCalibrator.cpp new file mode 100644 index 0000000..36670a7 --- /dev/null +++ b/source/tdis/tracking/RefittingCalibrator.cpp @@ -0,0 +1,43 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "RefittingCalibrator.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/EventData/MeasurementHelpers.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" + +namespace ActsExamples { + +void RefittingCalibrator::calibrate(const Acts::GeometryContext& /*gctx*/, + const Acts::CalibrationContext& /*cctx*/, + const Acts::SourceLink& sourceLink, + Proxy trackState) const { + const auto sl = sourceLink.get(); + + // Reset the original uncalibrated source link on this track state + trackState.setUncalibratedSourceLink(sl.state.getUncalibratedSourceLink()); + + // Here we construct a measurement by extracting the information available + // in the state + Acts::visit_measurement(sl.state.calibratedSize(), [&](auto N) { + using namespace Acts; + constexpr int Size = decltype(N)::value; + + trackState.allocateCalibrated(Size); + trackState.template calibrated() = + sl.state.template calibrated(); + trackState.template calibratedCovariance() = + sl.state.template calibratedCovariance(); + }); + + trackState.setBoundSubspaceIndices(sl.state.boundSubspaceIndices()); +} + +} // namespace ActsExamples diff --git a/source/tdis/tracking/RefittingCalibrator.hpp b/source/tdis/tracking/RefittingCalibrator.hpp new file mode 100644 index 0000000..31c1965 --- /dev/null +++ b/source/tdis/tracking/RefittingCalibrator.hpp @@ -0,0 +1,45 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/Geometry/GeometryContext.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Surfaces/Surface.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" + +namespace Acts { +class ConstVectorMultiTrajectory; +class VectorMultiTrajectory; +} // namespace Acts + +namespace ActsExamples { + +struct RefittingCalibrator { + using Proxy = Acts::VectorMultiTrajectory::TrackStateProxy; + using ConstProxy = Acts::ConstVectorMultiTrajectory::ConstTrackStateProxy; + + struct RefittingSourceLink { + ConstProxy state; + }; + + static const Acts::Surface* accessSurface( + const Acts::SourceLink& sourceLink) { + const auto& refittingSl = sourceLink.get(); + return &refittingSl.state.referenceSurface(); + } + + void calibrate(const Acts::GeometryContext& gctx, + const Acts::CalibrationContext& cctx, + const Acts::SourceLink& sourceLink, Proxy trackState) const; +}; + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitterFunction.hpp b/source/tdis/tracking/TrackFitterFunction.hpp new file mode 100644 index 0000000..34a83b3 --- /dev/null +++ b/source/tdis/tracking/TrackFitterFunction.hpp @@ -0,0 +1,119 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/TrackParameters.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Geometry/GeometryContext.hpp" +#include "Acts/Geometry/TrackingGeometry.hpp" +#include "Acts/MagneticField/MagneticFieldContext.hpp" +#include "Acts/MagneticField/MagneticFieldProvider.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/TrackFitting/BetheHeitlerApprox.hpp" +#include "Acts/TrackFitting/GsfOptions.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" + +namespace ActsExamples { + +/// Fit function that takes the above parameters and runs a fit +/// @note This is separated into a virtual interface to keep compilation units +/// small. +class TrackFitterFunction { + public: + using TrackFitterResult = Acts::Result; + + struct GeneralFitterOptions { + std::reference_wrapper geoContext; + std::reference_wrapper magFieldContext; + std::reference_wrapper calibrationContext; + const Acts::Surface* referenceSurface = nullptr; + Acts::PropagatorPlainOptions propOptions; + bool doRefit = false; + }; + + virtual ~TrackFitterFunction() = default; + + virtual TrackFitterResult operator()(const std::vector&, + const TrackParameters&, + const GeneralFitterOptions&, + const MeasurementCalibratorAdapter&, + TrackContainer&) const = 0; + + virtual TrackFitterResult operator()(const std::vector&, + const TrackParameters&, + const GeneralFitterOptions&, + const RefittingCalibrator&, + const std::vector&, + TrackContainer&) const = 0; +}; + +/// Makes a fitter function object for the Kalman Filter +/// +std::shared_ptr makeKalmanFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering = true, bool energyLoss = true, + double reverseFilteringMomThreshold = 0.0, + Acts::FreeToBoundCorrection freeToBoundCorrection = Acts::FreeToBoundCorrection(), + const Acts::Logger& logger = *Acts::getDefaultLogger("Kalman", Acts::Logging::INFO)); + +/// This type is used in the Examples framework for the Bethe-Heitler +/// approximation +using BetheHeitlerApprox = Acts::AtlasBetheHeitlerApprox<6, 5>; + +/// Available algorithms for the mixture reduction +enum class MixtureReductionAlgorithm { weightCut, KLDistance }; + +/// Makes a fitter function object for the GSF +/// +/// @param trackingGeometry the trackingGeometry for the propagator +/// @param magneticField the magnetic field for the propagator +/// @param betheHeitlerApprox The object that encapsulates the approximation. +/// @param maxComponents number of maximum components in the track state +/// @param weightCutoff when to drop components +/// @param componentMergeMethod How to merge a mixture to a single set of +/// parameters and covariance +/// @param mixtureReductionAlgorithm How to reduce the number of components +/// in a mixture +/// @param logger a logger instance +std::shared_ptr makeGsfFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents, + double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod, + MixtureReductionAlgorithm mixtureReductionAlgorithm, + const Acts::Logger& logger); + +/// Makes a fitter function object for the Global Chi Square Fitter (GX2F) +/// +/// @param trackingGeometry the trackingGeometry for the propagator +/// @param magneticField the magnetic field for the propagator +/// @param multipleScattering bool +/// @param energyLoss bool +/// @param freeToBoundCorrection bool +/// @param nUpdateMax max number of iterations during the fit +/// @param relChi2changeCutOff Check for convergence (abort condition). Set to 0 to skip. +/// @param logger a logger instance +std::shared_ptr makeGlobalChiSquareFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering = true, bool energyLoss = true, + Acts::FreeToBoundCorrection freeToBoundCorrection = + Acts::FreeToBoundCorrection(), + std::size_t nUpdateMax = 5, double relChi2changeCutOff = 1e-7, + const Acts::Logger& logger = *Acts::getDefaultLogger("Gx2f", + Acts::Logging::INFO)); + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/CMakeLists.txt b/source/tdis/tracking/TrackFitting/CMakeLists.txt new file mode 100644 index 0000000..07e2609 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/CMakeLists.txt @@ -0,0 +1,24 @@ +add_library( + ActsExamplesTrackFitting + SHARED + src/RefittingCalibrator.cpp + src/SurfaceSortingAlgorithm.cpp + src/TrackFittingAlgorithm.cpp + src/KalmanFitterFunction.cpp + src/RefittingAlgorithm.cpp + src/GsfFitterFunction.cpp + src/GlobalChiSquareFitterFunction.cpp +) +target_include_directories( + ActsExamplesTrackFitting + PUBLIC $ +) +target_link_libraries( + ActsExamplesTrackFitting + PUBLIC ActsCore ActsExamplesFramework ActsExamplesMagneticField +) + +install( + TARGETS ActsExamplesTrackFitting + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingAlgorithm.hpp b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingAlgorithm.hpp new file mode 100644 index 0000000..19e08f7 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingAlgorithm.hpp @@ -0,0 +1,60 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include + +namespace ActsExamples { +class TrackFitterFunction; +struct AlgorithmContext; + +class RefittingAlgorithm final : public IAlgorithm { + public: + struct Config { + /// The input track collection + std::string inputTracks; + /// Output fitted tracks collection. + std::string outputTracks; + /// Type erased fitter function. + std::shared_ptr fit; + /// Pick a single track for debugging (-1 process all tracks) + int pickTrack = -1; + }; + + /// Constructor of the fitting algorithm + /// + /// @param config is the config struct to configure the algorithm + /// @param level is the logging level + RefittingAlgorithm(Config config, Acts::Logging::Level level); + + /// Framework execute method of the fitting algorithm + /// + /// @param ctx is the algorithm context that holds event-wise information + /// @return a process code to steer the algporithm flow + ActsExamples::ProcessCode execute(const AlgorithmContext& ctx) const final; + + /// Get readonly access to the config parameters + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + ReadDataHandle m_inputTracks{this, "InputTracks"}; + WriteDataHandle m_outputTracks{this, "OutputTracks"}; +}; + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingCalibrator.hpp b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingCalibrator.hpp new file mode 100644 index 0000000..31c1965 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/RefittingCalibrator.hpp @@ -0,0 +1,45 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/Geometry/GeometryContext.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Surfaces/Surface.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" + +namespace Acts { +class ConstVectorMultiTrajectory; +class VectorMultiTrajectory; +} // namespace Acts + +namespace ActsExamples { + +struct RefittingCalibrator { + using Proxy = Acts::VectorMultiTrajectory::TrackStateProxy; + using ConstProxy = Acts::ConstVectorMultiTrajectory::ConstTrackStateProxy; + + struct RefittingSourceLink { + ConstProxy state; + }; + + static const Acts::Surface* accessSurface( + const Acts::SourceLink& sourceLink) { + const auto& refittingSl = sourceLink.get(); + return &refittingSl.state.referenceSurface(); + } + + void calibrate(const Acts::GeometryContext& gctx, + const Acts::CalibrationContext& cctx, + const Acts::SourceLink& sourceLink, Proxy trackState) const; +}; + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp new file mode 100644 index 0000000..02f85bf --- /dev/null +++ b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp @@ -0,0 +1,64 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Index.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/EventData/SimHit.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" + +#include +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; + +using TrackHitList = std::map; + +class SurfaceSortingAlgorithm final : public IAlgorithm { + public: + struct Config { + /// Input proto track collection + std::string inputProtoTracks; + /// Input simulated hit collection + std::string inputSimHits; + /// Input measurement to simulated hit map for truth position + std::string inputMeasurementSimHitsMap; + /// Output proto track collection + std::string outputProtoTracks; + }; + + SurfaceSortingAlgorithm(Config cfg, Acts::Logging::Level level); + + ActsExamples::ProcessCode execute(const AlgorithmContext& ctx) const final; + + /// Get readonly access to the config parameters + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + ReadDataHandle m_inputProtoTracks{this, + "InputProtoTracks"}; + ReadDataHandle m_inputSimHits{this, "InputSimHits"}; + ReadDataHandle m_inputMeasurementSimHitsMap{ + this, "InputMeasurementSimHitsMap"}; + WriteDataHandle m_outputProtoTracks{this, + "OutputProtoTracks"}; +}; + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp new file mode 100644 index 0000000..a58d0a9 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp @@ -0,0 +1,121 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/TrackParameters.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Geometry/GeometryContext.hpp" +#include "Acts/Geometry/TrackingGeometry.hpp" +#include "Acts/MagneticField/MagneticFieldContext.hpp" +#include "Acts/MagneticField/MagneticFieldProvider.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/TrackFitting/BetheHeitlerApprox.hpp" +#include "Acts/TrackFitting/GsfOptions.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" + +namespace ActsExamples { + +/// Fit function that takes the above parameters and runs a fit +/// @note This is separated into a virtual interface to keep compilation units +/// small. +class TrackFitterFunction { + public: + using TrackFitterResult = Acts::Result; + + struct GeneralFitterOptions { + std::reference_wrapper geoContext; + std::reference_wrapper magFieldContext; + std::reference_wrapper calibrationContext; + const Acts::Surface* referenceSurface = nullptr; + Acts::PropagatorPlainOptions propOptions; + bool doRefit = false; + }; + + virtual ~TrackFitterFunction() = default; + + virtual TrackFitterResult operator()(const std::vector&, + const TrackParameters&, + const GeneralFitterOptions&, + const MeasurementCalibratorAdapter&, + TrackContainer&) const = 0; + + virtual TrackFitterResult operator()(const std::vector&, + const TrackParameters&, + const GeneralFitterOptions&, + const RefittingCalibrator&, + const std::vector&, + TrackContainer&) const = 0; +}; + +/// Makes a fitter function object for the Kalman Filter +/// +std::shared_ptr makeKalmanFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering = true, bool energyLoss = true, + double reverseFilteringMomThreshold = 0.0, + Acts::FreeToBoundCorrection freeToBoundCorrection = + Acts::FreeToBoundCorrection(), + const Acts::Logger& logger = *Acts::getDefaultLogger("Kalman", + Acts::Logging::INFO)); + +/// This type is used in the Examples framework for the Bethe-Heitler +/// approximation +using BetheHeitlerApprox = Acts::AtlasBetheHeitlerApprox<6, 5>; + +/// Available algorithms for the mixture reduction +enum class MixtureReductionAlgorithm { weightCut, KLDistance }; + +/// Makes a fitter function object for the GSF +/// +/// @param trackingGeometry the trackingGeometry for the propagator +/// @param magneticField the magnetic field for the propagator +/// @param betheHeitlerApprox The object that encapsulates the approximation. +/// @param maxComponents number of maximum components in the track state +/// @param weightCutoff when to drop components +/// @param componentMergeMethod How to merge a mixture to a single set of +/// parameters and covariance +/// @param mixtureReductionAlgorithm How to reduce the number of components +/// in a mixture +/// @param logger a logger instance +std::shared_ptr makeGsfFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents, + double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod, + MixtureReductionAlgorithm mixtureReductionAlgorithm, + const Acts::Logger& logger); + +/// Makes a fitter function object for the Global Chi Square Fitter (GX2F) +/// +/// @param trackingGeometry the trackingGeometry for the propagator +/// @param magneticField the magnetic field for the propagator +/// @param multipleScattering bool +/// @param energyLoss bool +/// @param freeToBoundCorrection bool +/// @param nUpdateMax max number of iterations during the fit +/// @param relChi2changeCutOff Check for convergence (abort condition). Set to 0 to skip. +/// @param logger a logger instance +std::shared_ptr makeGlobalChiSquareFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering = true, bool energyLoss = true, + Acts::FreeToBoundCorrection freeToBoundCorrection = + Acts::FreeToBoundCorrection(), + std::size_t nUpdateMax = 5, double relChi2changeCutOff = 1e-7, + const Acts::Logger& logger = *Acts::getDefaultLogger("Gx2f", + Acts::Logging::INFO)); + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp new file mode 100644 index 0000000..934203a --- /dev/null +++ b/source/tdis/tracking/TrackFitting/include/ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp @@ -0,0 +1,85 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#pragma once + +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/Cluster.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/Framework/DataHandle.hpp" +#include "ActsExamples/Framework/IAlgorithm.hpp" +#include "ActsExamples/Framework/ProcessCode.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include + +namespace Acts { +class TrackingGeometry; +} + +namespace ActsExamples { +class MeasurementCalibrator; +class TrackFitterFunction; +struct AlgorithmContext; + +class TrackFittingAlgorithm final : public IAlgorithm { + public: + struct Config { + /// Input measurements collection. + std::string inputMeasurements; + /// Input proto tracks collection, i.e. groups of hit indices. + std::string inputProtoTracks; + /// Input initial track parameter estimates for for each proto track. + std::string inputInitialTrackParameters; + /// (optional) Input clusters for each measurement + std::string inputClusters; + /// Output fitted tracks collection. + std::string outputTracks; + /// Type erased fitter function. + std::shared_ptr fit; + /// Pick a single track for debugging (-1 process all tracks) + int pickTrack = -1; + // Type erased calibrator for the measurements + std::shared_ptr calibrator; + }; + + /// Constructor of the fitting algorithm + /// + /// @param config is the config struct to configure the algorithm + /// @param level is the logging level + TrackFittingAlgorithm(Config config, Acts::Logging::Level level); + + /// Framework execute method of the fitting algorithm + /// + /// @param ctx is the algorithm context that holds event-wise information + /// @return a process code to steer the algporithm flow + ActsExamples::ProcessCode execute(const AlgorithmContext& ctx) const final; + + /// Get readonly access to the config parameters + const Config& config() const { return m_cfg; } + + private: + Config m_cfg; + + ReadDataHandle m_inputMeasurements{this, + "InputMeasurements"}; + ReadDataHandle m_inputProtoTracks{this, + "InputProtoTracks"}; + ReadDataHandle m_inputInitialTrackParameters{ + this, "InputInitialTrackParameters"}; + + ReadDataHandle m_inputClusters{this, "InputClusters"}; + + WriteDataHandle m_outputTracks{this, "OutputTracks"}; +}; + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/src/GlobalChiSquareFitterFunction.cpp b/source/tdis/tracking/TrackFitting/src/GlobalChiSquareFitterFunction.cpp new file mode 100644 index 0000000..4f6f6a9 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/GlobalChiSquareFitterFunction.cpp @@ -0,0 +1,171 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +// TODO We still use some Kalman Fitter functionalities. Check for replacement + +#include "Acts/Definitions/Direction.hpp" +#include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" +#include "Acts/Propagator/Navigator.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Propagator/SympyStepper.hpp" +#include "Acts/TrackFitting/GlobalChiSquareFitter.hpp" +#include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include + +namespace Acts { +class MagneticFieldProvider; +class SourceLink; +class Surface; +class TrackingGeometry; +} // namespace Acts + +namespace { + +using Stepper = Acts::SympyStepper; +using Propagator = Acts::Propagator; +using Fitter = + Acts::Experimental::Gx2Fitter; +using DirectPropagator = Acts::Propagator; +using DirectFitter = + Acts::KalmanFitter; + +using TrackContainer = + Acts::TrackContainer; + +using namespace ActsExamples; + +struct GlobalChiSquareFitterFunctionImpl final : public TrackFitterFunction { + Fitter fitter; + DirectFitter directFitter; + + bool multipleScattering = false; + bool energyLoss = false; + Acts::FreeToBoundCorrection freeToBoundCorrection; + std::size_t nUpdateMax = 5; + double relChi2changeCutOff = 1e-7; + + IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor; + + GlobalChiSquareFitterFunctionImpl(Fitter&& f, DirectFitter&& df, + const Acts::TrackingGeometry& trkGeo) + : fitter(std::move(f)), + directFitter(std::move(df)), + m_slSurfaceAccessor{trkGeo} {} + + template + auto makeGx2fOptions(const GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::Experimental::Gx2FitterExtensions + extensions; + extensions.calibrator.connect<&calibrator_t::calibrate>(&calibrator); + + if (options.doRefit) { + extensions.surfaceAccessor.connect<&RefittingCalibrator::accessSurface>(); + } else { + extensions.surfaceAccessor + .connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &m_slSurfaceAccessor); + } + + const Acts::Experimental::Gx2FitterOptions gx2fOptions( + options.geoContext, options.magFieldContext, options.calibrationContext, + extensions, options.propOptions, &(*options.referenceSurface), + multipleScattering, energyLoss, freeToBoundCorrection, nUpdateMax, + relChi2changeCutOff); + + return gx2fOptions; + } + + TrackFitterResult operator()(const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks) const override { + const auto gx2fOptions = makeGx2fOptions(options, calibrator); + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + gx2fOptions, tracks); + } + + // We need a placeholder for the directNavigator overload. Otherwise, we would + // have an unimplemented pure virtual method in a final class. + TrackFitterResult operator()( + const std::vector& /*sourceLinks*/, + const TrackParameters& /*initialParameters*/, + const GeneralFitterOptions& /*options*/, + const RefittingCalibrator& /*calibrator*/, + const std::vector& /*surfaceSequence*/, + TrackContainer& /*tracks*/) const override { + throw std::runtime_error( + "direct navigation with GX2 fitter is not implemented"); + } +}; + +} // namespace + +std::shared_ptr +ActsExamples::makeGlobalChiSquareFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering, bool energyLoss, + Acts::FreeToBoundCorrection freeToBoundCorrection, std::size_t nUpdateMax, + double relChi2changeCutOff, const Acts::Logger& logger) { + // Stepper should be copied into the fitters + const Stepper stepper(std::move(magneticField)); + + // Standard fitter + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; + cfg.resolvePassive = false; + cfg.resolveMaterial = true; + cfg.resolveSensitive = true; + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(stepper, std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); + + // Direct fitter + Acts::DirectNavigator directNavigator{ + logger.cloneWithSuffix("DirectNavigator")}; + DirectPropagator directPropagator(stepper, std::move(directNavigator), + logger.cloneWithSuffix("DirectPropagator")); + DirectFitter directTrackFitter(std::move(directPropagator), + logger.cloneWithSuffix("DirectFitter")); + + // build the fitter function. owns the fitter object. + auto fitterFunction = std::make_shared( + std::move(trackFitter), std::move(directTrackFitter), geo); + fitterFunction->multipleScattering = multipleScattering; + fitterFunction->energyLoss = energyLoss; + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; + fitterFunction->nUpdateMax = nUpdateMax; + fitterFunction->relChi2changeCutOff = relChi2changeCutOff; + + return fitterFunction; +} diff --git a/source/tdis/tracking/TrackFitting/src/GsfFitterFunction.cpp b/source/tdis/tracking/TrackFitting/src/GsfFitterFunction.cpp new file mode 100644 index 0000000..7e893ff --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/GsfFitterFunction.cpp @@ -0,0 +1,234 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "Acts/Definitions/Common.hpp" +#include "Acts/Definitions/Direction.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackParameters.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" +#include "Acts/Propagator/MultiEigenStepperLoop.hpp" +#include "Acts/Propagator/Navigator.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/TrackFitting/GainMatrixUpdater.hpp" +#include "Acts/TrackFitting/GaussianSumFitter.hpp" +#include "Acts/TrackFitting/GsfMixtureReduction.hpp" +#include "Acts/TrackFitting/GsfOptions.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/HashedString.hpp" +#include "Acts/Utilities/Intersection.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "Acts/Utilities/Zip.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace Acts { +class MagneticFieldProvider; +class SourceLink; +class Surface; +class TrackingGeometry; +} // namespace Acts + +using namespace ActsExamples; + +namespace { + +using MultiStepper = + Acts::MultiEigenStepperLoop; +using Propagator = Acts::Propagator; +using DirectPropagator = Acts::Propagator; + +using Fitter = Acts::GaussianSumFitter; +using DirectFitter = + Acts::GaussianSumFitter; +using TrackContainer = + Acts::TrackContainer; + +struct GsfFitterFunctionImpl final : public ActsExamples::TrackFitterFunction { + Fitter fitter; + DirectFitter directFitter; + + Acts::GainMatrixUpdater updater; + + std::size_t maxComponents = 0; + double weightCutoff = 0; + const double momentumCutoff = 0; // 500_MeV; + bool abortOnError = false; + bool disableAllMaterialHandling = false; + MixtureReductionAlgorithm reductionAlg = + MixtureReductionAlgorithm::KLDistance; + Acts::ComponentMergeMethod mergeMethod = + Acts::ComponentMergeMethod::eMaxWeight; + + IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor; + + GsfFitterFunctionImpl(Fitter&& f, DirectFitter&& df, + const Acts::TrackingGeometry& trkGeo) + : fitter(std::move(f)), + directFitter(std::move(df)), + m_slSurfaceAccessor{trkGeo} {} + + template + auto makeGsfOptions(const GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::GsfExtensions extensions; + extensions.updater.connect< + &Acts::GainMatrixUpdater::operator()>( + &updater); + + Acts::GsfOptions gsfOptions{ + options.geoContext, options.magFieldContext, + options.calibrationContext}; + gsfOptions.extensions = extensions; + gsfOptions.propagatorPlainOptions = options.propOptions; + gsfOptions.referenceSurface = options.referenceSurface; + gsfOptions.maxComponents = maxComponents; + gsfOptions.weightCutoff = weightCutoff; + gsfOptions.abortOnError = abortOnError; + gsfOptions.disableAllMaterialHandling = disableAllMaterialHandling; + gsfOptions.componentMergeMethod = mergeMethod; + + gsfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>( + &calibrator); + + if (options.doRefit) { + gsfOptions.extensions.surfaceAccessor + .connect<&RefittingCalibrator::accessSurface>(); + } else { + gsfOptions.extensions.surfaceAccessor + .connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &m_slSurfaceAccessor); + } + switch (reductionAlg) { + case MixtureReductionAlgorithm::weightCut: { + gsfOptions.extensions.mixtureReducer + .connect<&Acts::reduceMixtureLargestWeights>(); + } break; + case MixtureReductionAlgorithm::KLDistance: { + gsfOptions.extensions.mixtureReducer + .connect<&Acts::reduceMixtureWithKLDistance>(); + } break; + } + + return gsfOptions; + } + + TrackFitterResult operator()(const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks) const override { + const auto gsfOptions = makeGsfOptions(options, calibrator); + + using namespace Acts::GsfConstants; + if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) { + std::string key(kFinalMultiComponentStateColumn); + tracks.template addColumn(key); + } + + if (!tracks.hasColumn(Acts::hashString(kFwdMaxMaterialXOverX0))) { + tracks.template addColumn(std::string(kFwdMaxMaterialXOverX0)); + } + + if (!tracks.hasColumn(Acts::hashString(kFwdSumMaterialXOverX0))) { + tracks.template addColumn(std::string(kFwdSumMaterialXOverX0)); + } + + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + gsfOptions, tracks); + } + + TrackFitterResult operator()( + const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const RefittingCalibrator& calibrator, + const std::vector& surfaceSequence, + TrackContainer& tracks) const override { + const auto gsfOptions = makeGsfOptions(options, calibrator); + + using namespace Acts::GsfConstants; + if (!tracks.hasColumn(Acts::hashString(kFinalMultiComponentStateColumn))) { + std::string key(kFinalMultiComponentStateColumn); + tracks.template addColumn(key); + } + + return directFitter.fit(sourceLinks.begin(), sourceLinks.end(), + initialParameters, gsfOptions, surfaceSequence, + tracks); + } +}; + +} // namespace + +std::shared_ptr ActsExamples::makeGsfFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + BetheHeitlerApprox betheHeitlerApprox, std::size_t maxComponents, + double weightCutoff, Acts::ComponentMergeMethod componentMergeMethod, + MixtureReductionAlgorithm mixtureReductionAlgorithm, + const Acts::Logger& logger) { + // Standard fitter + MultiStepper stepper(magneticField, logger.cloneWithSuffix("Step")); + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; + cfg.resolvePassive = false; + cfg.resolveMaterial = true; + cfg.resolveSensitive = true; + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(std::move(stepper), std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), + BetheHeitlerApprox(betheHeitlerApprox), + logger.cloneWithSuffix("GSF")); + + // Direct fitter + MultiStepper directStepper(std::move(magneticField), + logger.cloneWithSuffix("Step")); + Acts::DirectNavigator directNavigator{ + logger.cloneWithSuffix("DirectNavigator")}; + DirectPropagator directPropagator(std::move(directStepper), + std::move(directNavigator), + logger.cloneWithSuffix("DirectPropagator")); + DirectFitter directTrackFitter(std::move(directPropagator), + BetheHeitlerApprox(betheHeitlerApprox), + logger.cloneWithSuffix("DirectGSF")); + + // build the fitter functions. owns the fitter object. + auto fitterFunction = std::make_shared( + std::move(trackFitter), std::move(directTrackFitter), geo); + fitterFunction->maxComponents = maxComponents; + fitterFunction->weightCutoff = weightCutoff; + fitterFunction->mergeMethod = componentMergeMethod; + fitterFunction->reductionAlg = mixtureReductionAlgorithm; + + return fitterFunction; +} diff --git a/source/tdis/tracking/TrackFitting/src/KalmanFitterFunction.cpp b/source/tdis/tracking/TrackFitting/src/KalmanFitterFunction.cpp new file mode 100644 index 0000000..6f835c5 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/KalmanFitterFunction.cpp @@ -0,0 +1,196 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "Acts/Definitions/Direction.hpp" +#include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" +#include "Acts/Propagator/Navigator.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Propagator/SympyStepper.hpp" +#include "Acts/TrackFitting/GainMatrixSmoother.hpp" +#include "Acts/TrackFitting/GainMatrixUpdater.hpp" +#include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include + +namespace Acts { +class MagneticFieldProvider; +class SourceLink; +class Surface; +class TrackingGeometry; +} // namespace Acts + +namespace { + +using Stepper = Acts::SympyStepper; +using Propagator = Acts::Propagator; +using Fitter = Acts::KalmanFitter; +using DirectPropagator = Acts::Propagator; +using DirectFitter = + Acts::KalmanFitter; + +using TrackContainer = + Acts::TrackContainer; + +struct SimpleReverseFilteringLogic { + double momentumThreshold = 0; + + bool doBackwardFiltering( + Acts::VectorMultiTrajectory::ConstTrackStateProxy trackState) const { + auto momentum = std::abs(1 / trackState.filtered()[Acts::eBoundQOverP]); + return (momentum <= momentumThreshold); + } +}; + +using namespace ActsExamples; + +struct KalmanFitterFunctionImpl final : public TrackFitterFunction { + Fitter fitter; + DirectFitter directFitter; + + Acts::GainMatrixUpdater kfUpdater; + Acts::GainMatrixSmoother kfSmoother; + SimpleReverseFilteringLogic reverseFilteringLogic; + + bool multipleScattering = false; + bool energyLoss = false; + Acts::FreeToBoundCorrection freeToBoundCorrection; + + IndexSourceLink::SurfaceAccessor slSurfaceAccessor; + + KalmanFitterFunctionImpl(Fitter&& f, DirectFitter&& df, + const Acts::TrackingGeometry& trkGeo) + : fitter(std::move(f)), + directFitter(std::move(df)), + slSurfaceAccessor{trkGeo} {} + + template + auto makeKfOptions(const GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::KalmanFitterExtensions extensions; + extensions.updater.connect< + &Acts::GainMatrixUpdater::operator()>( + &kfUpdater); + extensions.smoother.connect< + &Acts::GainMatrixSmoother::operator()>( + &kfSmoother); + extensions.reverseFilteringLogic + .connect<&SimpleReverseFilteringLogic::doBackwardFiltering>( + &reverseFilteringLogic); + + Acts::KalmanFitterOptions kfOptions( + options.geoContext, options.magFieldContext, options.calibrationContext, + extensions, options.propOptions, &(*options.referenceSurface)); + + kfOptions.referenceSurfaceStrategy = + Acts::KalmanFitterTargetSurfaceStrategy::first; + kfOptions.multipleScattering = multipleScattering; + kfOptions.energyLoss = energyLoss; + kfOptions.freeToBoundCorrection = freeToBoundCorrection; + kfOptions.extensions.calibrator.connect<&calibrator_t::calibrate>( + &calibrator); + + if (options.doRefit) { + kfOptions.extensions.surfaceAccessor + .connect<&RefittingCalibrator::accessSurface>(); + } else { + kfOptions.extensions.surfaceAccessor + .connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &slSurfaceAccessor); + } + + return kfOptions; + } + + TrackFitterResult operator()(const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks) const override { + const auto kfOptions = makeKfOptions(options, calibrator); + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + kfOptions, tracks); + } + + TrackFitterResult operator()( + const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const RefittingCalibrator& calibrator, + const std::vector& surfaceSequence, + TrackContainer& tracks) const override { + const auto kfOptions = makeKfOptions(options, calibrator); + return directFitter.fit(sourceLinks.begin(), sourceLinks.end(), + initialParameters, kfOptions, surfaceSequence, + tracks); + } +}; + +} // namespace + +std::shared_ptr +ActsExamples::makeKalmanFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering, bool energyLoss, + double reverseFilteringMomThreshold, + Acts::FreeToBoundCorrection freeToBoundCorrection, + const Acts::Logger& logger) { + // Stepper should be copied into the fitters + const Stepper stepper(std::move(magneticField)); + + // Standard fitter + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; + cfg.resolvePassive = false; + cfg.resolveMaterial = true; + cfg.resolveSensitive = true; + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(stepper, std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); + + // Direct fitter + Acts::DirectNavigator directNavigator{ + logger.cloneWithSuffix("DirectNavigator")}; + DirectPropagator directPropagator(stepper, std::move(directNavigator), + logger.cloneWithSuffix("DirectPropagator")); + DirectFitter directTrackFitter(std::move(directPropagator), + logger.cloneWithSuffix("DirectFitter")); + + // build the fitter function. owns the fitter object. + auto fitterFunction = std::make_shared( + std::move(trackFitter), std::move(directTrackFitter), geo); + fitterFunction->multipleScattering = multipleScattering; + fitterFunction->energyLoss = energyLoss; + fitterFunction->reverseFilteringLogic.momentumThreshold = + reverseFilteringMomThreshold; + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; + + return fitterFunction; +} diff --git a/source/tdis/tracking/TrackFitting/src/RefittingAlgorithm.cpp b/source/tdis/tracking/TrackFitting/src/RefittingAlgorithm.cpp new file mode 100644 index 0000000..03aa74f --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/RefittingAlgorithm.cpp @@ -0,0 +1,143 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "ActsExamples/TrackFitting/RefittingAlgorithm.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/EventData/GenericBoundTrackParameters.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackParameters.hpp" +#include "Acts/EventData/TrackProxy.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Surfaces/Surface.hpp" +#include "Acts/Utilities/Result.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +ActsExamples::RefittingAlgorithm::RefittingAlgorithm(Config config, + Acts::Logging::Level level) + : ActsExamples::IAlgorithm("TrackFittingAlgorithm", level), + m_cfg(std::move(config)) { + if (m_cfg.inputTracks.empty()) { + throw std::invalid_argument("Missing input tracks collection"); + } + if (m_cfg.outputTracks.empty()) { + throw std::invalid_argument("Missing output tracks collection"); + } + + m_inputTracks.initialize(m_cfg.inputTracks); + m_outputTracks.initialize(m_cfg.outputTracks); +} + +ActsExamples::ProcessCode ActsExamples::RefittingAlgorithm::execute( + const ActsExamples::AlgorithmContext& ctx) const { + const auto& inputTracks = m_inputTracks(ctx); + + auto trackContainer = std::make_shared(); + auto trackStateContainer = std::make_shared(); + TrackContainer tracks(trackContainer, trackStateContainer); + + // Perform the fit for each input track + std::vector trackSourceLinks; + std::vector surfSequence; + RefittingCalibrator calibrator; + + auto itrack = 0ul; + for (const auto& track : inputTracks) { + // Check if you are not in picking mode + if (m_cfg.pickTrack > -1 && m_cfg.pickTrack != static_cast(itrack++)) { + continue; + } + + if (!track.hasReferenceSurface()) { + ACTS_VERBOSE("Skip track " << itrack << ": missing ref surface"); + continue; + } + + TrackFitterFunction::GeneralFitterOptions options{ + ctx.geoContext, ctx.magFieldContext, ctx.calibContext, + &track.referenceSurface(), + Acts::PropagatorPlainOptions(ctx.geoContext, ctx.magFieldContext)}; + options.doRefit = true; + + const Acts::BoundTrackParameters initialParams( + track.referenceSurface().getSharedPtr(), track.parameters(), + track.covariance(), track.particleHypothesis()); + + trackSourceLinks.clear(); + surfSequence.clear(); + + for (auto state : track.trackStatesReversed()) { + surfSequence.push_back(&state.referenceSurface()); + + if (!state.hasCalibrated()) { + continue; + } + + auto sl = RefittingCalibrator::RefittingSourceLink{state}; + trackSourceLinks.push_back(Acts::SourceLink{sl}); + } + + if (surfSequence.empty()) { + ACTS_WARNING("Empty track " << itrack << " found."); + continue; + } + + std::ranges::reverse(surfSequence); + + ACTS_VERBOSE("Initial parameters: " + << initialParams.fourPosition(ctx.geoContext).transpose() + << " -> " << initialParams.direction().transpose()); + + ACTS_DEBUG("Invoke direct fitter for track " << itrack); + auto result = (*m_cfg.fit)(trackSourceLinks, initialParams, options, + calibrator, surfSequence, tracks); + + if (result.ok()) { + // Get the fit output object + const auto& refittedTrack = result.value(); + if (refittedTrack.hasReferenceSurface()) { + ACTS_VERBOSE("Refitted parameters for track " << itrack); + ACTS_VERBOSE(" " << track.parameters().transpose()); + } else { + ACTS_DEBUG("No refitted parameters for track " << itrack); + } + } else { + ACTS_WARNING("Fit failed for track " + << itrack << " with error: " << result.error() << ", " + << result.error().message()); + } + } + + std::stringstream ss; + trackStateContainer->statistics().toStream(ss); + ACTS_DEBUG(ss.str()); + + ConstTrackContainer constTracks{ + std::make_shared( + std::move(*trackContainer)), + std::make_shared( + std::move(*trackStateContainer))}; + + m_outputTracks(ctx, std::move(constTracks)); + return ActsExamples::ProcessCode::SUCCESS; +} diff --git a/source/tdis/tracking/TrackFitting/src/RefittingCalibrator.cpp b/source/tdis/tracking/TrackFitting/src/RefittingCalibrator.cpp new file mode 100644 index 0000000..b54efd0 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/RefittingCalibrator.cpp @@ -0,0 +1,43 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/EventData/MeasurementHelpers.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/Utilities/CalibrationContext.hpp" + +namespace ActsExamples { + +void RefittingCalibrator::calibrate(const Acts::GeometryContext& /*gctx*/, + const Acts::CalibrationContext& /*cctx*/, + const Acts::SourceLink& sourceLink, + Proxy trackState) const { + const auto sl = sourceLink.get(); + + // Reset the original uncalibrated source link on this track state + trackState.setUncalibratedSourceLink(sl.state.getUncalibratedSourceLink()); + + // Here we construct a measurement by extracting the information available + // in the state + Acts::visit_measurement(sl.state.calibratedSize(), [&](auto N) { + using namespace Acts; + constexpr int Size = decltype(N)::value; + + trackState.allocateCalibrated(Size); + trackState.template calibrated() = + sl.state.template calibrated(); + trackState.template calibratedCovariance() = + sl.state.template calibratedCovariance(); + }); + + trackState.setBoundSubspaceIndices(sl.state.boundSubspaceIndices()); +} + +} // namespace ActsExamples diff --git a/source/tdis/tracking/TrackFitting/src/SurfaceSortingAlgorithm.cpp b/source/tdis/tracking/TrackFitting/src/SurfaceSortingAlgorithm.cpp new file mode 100644 index 0000000..040634e --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/SurfaceSortingAlgorithm.cpp @@ -0,0 +1,87 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "ActsExamples/TrackFitting/SurfaceSortingAlgorithm.hpp" + +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/EventData/SimHit.hpp" +#include "ActsFatras/EventData/Hit.hpp" + +#include +#include +#include +#include +#include + +namespace ActsExamples { +struct AlgorithmContext; +} // namespace ActsExamples + +ActsExamples::SurfaceSortingAlgorithm::SurfaceSortingAlgorithm( + Config cfg, Acts::Logging::Level level) + : ActsExamples::IAlgorithm("SurfaceSortingAlgorithm", level), + m_cfg(std::move(cfg)) { + if (m_cfg.inputProtoTracks.empty()) { + throw std::invalid_argument("Missing input proto track collection"); + } + if (m_cfg.inputSimHits.empty()) { + throw std::invalid_argument("Missing input simulated hits collection"); + } + if (m_cfg.inputMeasurementSimHitsMap.empty()) { + throw std::invalid_argument("Missing input measurement sim hits map"); + } + if (m_cfg.outputProtoTracks.empty()) { + throw std::invalid_argument("Missing output proto track collection"); + } + + m_inputProtoTracks.initialize(m_cfg.inputProtoTracks); + m_inputSimHits.initialize(m_cfg.inputSimHits); + m_inputMeasurementSimHitsMap.initialize(m_cfg.inputMeasurementSimHitsMap); + m_outputProtoTracks.initialize(m_cfg.outputProtoTracks); +} + +ActsExamples::ProcessCode ActsExamples::SurfaceSortingAlgorithm::execute( + const ActsExamples::AlgorithmContext& ctx) const { + const auto& protoTracks = m_inputProtoTracks(ctx); + const auto& simHits = m_inputSimHits(ctx); + const auto& simHitsMap = m_inputMeasurementSimHitsMap(ctx); + + ProtoTrackContainer sortedTracks; + sortedTracks.reserve(protoTracks.size()); + TrackHitList trackHitList; + + for (std::size_t itrack = 0; itrack < protoTracks.size(); ++itrack) { + const auto& protoTrack = protoTracks[itrack]; + + ProtoTrack sortedProtoTrack; + sortedProtoTrack.reserve(protoTrack.size()); + trackHitList.clear(); + + if (protoTrack.empty()) { + continue; + } + + for (const auto hit : protoTrack) { + const auto simHitIndex = simHitsMap.find(hit)->second; + auto simHit = simHits.nth(simHitIndex); + auto simHitTime = simHit->time(); + trackHitList.insert(std::make_pair(simHitTime, hit)); + } + + /// Map will now be sorted by truth hit time + for (auto const& [time, hit] : trackHitList) { + sortedProtoTrack.emplace_back(hit); + } + + sortedTracks.emplace_back(std::move(sortedProtoTrack)); + } + + m_outputProtoTracks(ctx, std::move(sortedTracks)); + + return ActsExamples::ProcessCode::SUCCESS; +} diff --git a/source/tdis/tracking/TrackFitting/src/TrackFittingAlgorithm.cpp b/source/tdis/tracking/TrackFitting/src/TrackFittingAlgorithm.cpp new file mode 100644 index 0000000..50360e3 --- /dev/null +++ b/source/tdis/tracking/TrackFitting/src/TrackFittingAlgorithm.cpp @@ -0,0 +1,169 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/EventData/GenericBoundTrackParameters.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/TrackProxy.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Surfaces/PerigeeSurface.hpp" +#include "Acts/Surfaces/Surface.hpp" +#include "Acts/Utilities/Result.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include +#include + +ActsExamples::TrackFittingAlgorithm::TrackFittingAlgorithm( + Config config, Acts::Logging::Level level) + : ActsExamples::IAlgorithm("TrackFittingAlgorithm", level), + m_cfg(std::move(config)) { + if (m_cfg.inputMeasurements.empty()) { + throw std::invalid_argument("Missing input measurement collection"); + } + if (m_cfg.inputProtoTracks.empty()) { + throw std::invalid_argument("Missing input proto tracks collection"); + } + if (m_cfg.inputInitialTrackParameters.empty()) { + throw std::invalid_argument( + "Missing input initial track parameters collection"); + } + if (m_cfg.outputTracks.empty()) { + throw std::invalid_argument("Missing output tracks collection"); + } + if (!m_cfg.calibrator) { + throw std::invalid_argument("Missing calibrator"); + } + if (m_cfg.inputClusters.empty() && m_cfg.calibrator->needsClusters()) { + throw std::invalid_argument("The configured calibrator needs clusters"); + } + + m_inputMeasurements.initialize(m_cfg.inputMeasurements); + m_inputProtoTracks.initialize(m_cfg.inputProtoTracks); + m_inputInitialTrackParameters.initialize(m_cfg.inputInitialTrackParameters); + m_inputClusters.maybeInitialize(m_cfg.inputClusters); + m_outputTracks.initialize(m_cfg.outputTracks); +} + +ActsExamples::ProcessCode ActsExamples::TrackFittingAlgorithm::execute( + const ActsExamples::AlgorithmContext& ctx) const { + // Read input data + const auto& measurements = m_inputMeasurements(ctx); + const auto& protoTracks = m_inputProtoTracks(ctx); + const auto& initialParameters = m_inputInitialTrackParameters(ctx); + + const ClusterContainer* clusters = + m_inputClusters.isInitialized() ? &m_inputClusters(ctx) : nullptr; + + // Consistency cross checks + if (protoTracks.size() != initialParameters.size()) { + ACTS_FATAL("Inconsistent number of proto tracks and parameters " + << protoTracks.size() << " vs " << initialParameters.size()); + return ProcessCode::ABORT; + } + + // Construct a perigee surface as the target surface + auto pSurface = Acts::Surface::makeShared( + Acts::Vector3{0., 0., 0.}); + + // Measurement calibrator must be instantiated here, because we need the + // measurements to construct it. The other extensions are hold by the + // fit-function-object + ActsExamples::MeasurementCalibratorAdapter calibrator(*(m_cfg.calibrator), + measurements, clusters); + + TrackFitterFunction::GeneralFitterOptions options{ + ctx.geoContext, ctx.magFieldContext, ctx.calibContext, pSurface.get(), + Acts::PropagatorPlainOptions(ctx.geoContext, ctx.magFieldContext)}; + + auto trackContainer = std::make_shared(); + auto trackStateContainer = std::make_shared(); + TrackContainer tracks(trackContainer, trackStateContainer); + + // Perform the fit for each input track + std::vector trackSourceLinks; + for (std::size_t itrack = 0; itrack < protoTracks.size(); ++itrack) { + // Check if you are not in picking mode + if (m_cfg.pickTrack > -1 && m_cfg.pickTrack != static_cast(itrack)) { + continue; + } + + // The list of hits and the initial start parameters + const auto& protoTrack = protoTracks[itrack]; + const auto& initialParams = initialParameters[itrack]; + + // We can have empty tracks which must give empty fit results so the number + // of entries in input and output containers matches. + if (protoTrack.empty()) { + ACTS_WARNING("Empty track " << itrack << " found."); + continue; + } + + ACTS_VERBOSE("Initial parameters: " + << initialParams.fourPosition(ctx.geoContext).transpose() + << " -> " << initialParams.direction().transpose()); + + // Clear & reserve the right size + trackSourceLinks.clear(); + trackSourceLinks.reserve(protoTrack.size()); + + // Fill the source links via their indices from the container + for (auto measIndex : protoTrack) { + ConstVariableBoundMeasurementProxy measurement = + measurements.getMeasurement(measIndex); + IndexSourceLink sourceLink(measurement.geometryId(), measIndex); + trackSourceLinks.push_back(Acts::SourceLink(sourceLink)); + } + + ACTS_DEBUG("Invoke direct fitter for track " << itrack); + auto result = (*m_cfg.fit)(trackSourceLinks, initialParams, options, + calibrator, tracks); + + if (result.ok()) { + // Get the fit output object + const auto& track = result.value(); + if (track.hasReferenceSurface()) { + ACTS_VERBOSE("Fitted parameters for track " << itrack); + ACTS_VERBOSE(" " << track.parameters().transpose()); + } else { + ACTS_DEBUG("No fitted parameters for track " << itrack); + } + } else { + ACTS_WARNING("Fit failed for track " + << itrack << " with error: " << result.error() << ", " + << result.error().message()); + } + } + + std::stringstream ss; + trackStateContainer->statistics().toStream(ss); + ACTS_DEBUG(ss.str()); + + ConstTrackContainer constTracks{ + std::make_shared( + std::move(*trackContainer)), + std::make_shared( + std::move(*trackStateContainer))}; + + m_outputTracks(ctx, std::move(constTracks)); + return ActsExamples::ProcessCode::SUCCESS; +} diff --git a/source/tdis/tracking/TrackFittingAlgorithm.cpp b/source/tdis/tracking/TrackFittingAlgorithm.cpp new file mode 100644 index 0000000..a0ab70c --- /dev/null +++ b/source/tdis/tracking/TrackFittingAlgorithm.cpp @@ -0,0 +1,167 @@ +// This file is part of the ACTS project. +// +// Copyright (C) 2016 CERN for the benefit of the ACTS project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at https://mozilla.org/MPL/2.0/. + +#include "ActsExamples/TrackFitting/TrackFittingAlgorithm.hpp" + +#include "Acts/Definitions/Algebra.hpp" +#include "Acts/EventData/GenericBoundTrackParameters.hpp" +#include "Acts/EventData/SourceLink.hpp" +#include "Acts/EventData/TrackProxy.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/Surfaces/PerigeeSurface.hpp" +#include "Acts/Surfaces/Surface.hpp" +#include "Acts/Utilities/Result.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/Measurement.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/ProtoTrack.hpp" +#include "ActsExamples/Framework/AlgorithmContext.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include +#include + +ActsExamples::TrackFittingAlgorithm::TrackFittingAlgorithm(Config config, Acts::Logging::Level level): ActsExamples::IAlgorithm("TrackFittingAlgorithm", level), + m_cfg(std::move(config)) { + if (m_cfg.inputMeasurements.empty()) { + throw std::invalid_argument("Missing input measurement collection"); + } + if (m_cfg.inputProtoTracks.empty()) { + throw std::invalid_argument("Missing input proto tracks collection"); + } + if (m_cfg.inputInitialTrackParameters.empty()) { + throw std::invalid_argument( + "Missing input initial track parameters collection"); + } + if (m_cfg.outputTracks.empty()) { + throw std::invalid_argument("Missing output tracks collection"); + } + if (!m_cfg.calibrator) { + throw std::invalid_argument("Missing calibrator"); + } + if (m_cfg.inputClusters.empty() && m_cfg.calibrator->needsClusters()) { + throw std::invalid_argument("The configured calibrator needs clusters"); + } + + m_inputMeasurements.initialize(m_cfg.inputMeasurements); + m_inputProtoTracks.initialize(m_cfg.inputProtoTracks); + m_inputInitialTrackParameters.initialize(m_cfg.inputInitialTrackParameters); + m_inputClusters.maybeInitialize(m_cfg.inputClusters); + m_outputTracks.initialize(m_cfg.outputTracks); +} + +ActsExamples::ProcessCode ActsExamples::TrackFittingAlgorithm::execute( + const ActsExamples::AlgorithmContext& ctx) const { + // Read input data + const auto& measurements = m_inputMeasurements(ctx); + const auto& protoTracks = m_inputProtoTracks(ctx); + const auto& initialParameters = m_inputInitialTrackParameters(ctx); + + const ClusterContainer* clusters = + m_inputClusters.isInitialized() ? &m_inputClusters(ctx) : nullptr; + + // Consistency cross checks + if (protoTracks.size() != initialParameters.size()) { + ACTS_FATAL("Inconsistent number of proto tracks and parameters " + << protoTracks.size() << " vs " << initialParameters.size()); + return ProcessCode::ABORT; + } + + // Construct a perigee surface as the target surface + auto pSurface = Acts::Surface::makeShared( + Acts::Vector3{0., 0., 0.}); + + // Measurement calibrator must be instantiated here, because we need the + // measurements to construct it. The other extensions are hold by the + // fit-function-object + ActsExamples::MeasurementCalibratorAdapter calibrator(*(m_cfg.calibrator), + measurements, clusters); + + TrackFitterFunction::GeneralFitterOptions options{ + ctx.geoContext, ctx.magFieldContext, ctx.calibContext, pSurface.get(), + Acts::PropagatorPlainOptions(ctx.geoContext, ctx.magFieldContext)}; + + auto trackContainer = std::make_shared(); + auto trackStateContainer = std::make_shared(); + TrackContainer tracks(trackContainer, trackStateContainer); + + // Perform the fit for each input track + std::vector trackSourceLinks; + for (std::size_t itrack = 0; itrack < protoTracks.size(); ++itrack) { + // Check if you are not in picking mode + if (m_cfg.pickTrack > -1 && m_cfg.pickTrack != static_cast(itrack)) { + continue; + } + + // The list of hits and the initial start parameters + const auto& protoTrack = protoTracks[itrack]; + const auto& initialParams = initialParameters[itrack]; + + // We can have empty tracks which must give empty fit results so the number + // of entries in input and output containers matches. + if (protoTrack.empty()) { + ACTS_WARNING("Empty track " << itrack << " found."); + continue; + } + + ACTS_VERBOSE("Initial parameters: " + << initialParams.fourPosition(ctx.geoContext).transpose() + << " -> " << initialParams.direction().transpose()); + + // Clear & reserve the right size + trackSourceLinks.clear(); + trackSourceLinks.reserve(protoTrack.size()); + + // Fill the source links via their indices from the container + for (auto measIndex : protoTrack) { + ConstVariableBoundMeasurementProxy measurement = + measurements.getMeasurement(measIndex); + IndexSourceLink sourceLink(measurement.geometryId(), measIndex); + trackSourceLinks.push_back(Acts::SourceLink(sourceLink)); + } + + ACTS_DEBUG("Invoke direct fitter for track " << itrack); + auto result = (*m_cfg.fit)(trackSourceLinks, initialParams, options, + calibrator, tracks); + + if (result.ok()) { + // Get the fit output object + const auto& track = result.value(); + if (track.hasReferenceSurface()) { + ACTS_VERBOSE("Fitted parameters for track " << itrack); + ACTS_VERBOSE(" " << track.parameters().transpose()); + } else { + ACTS_DEBUG("No fitted parameters for track " << itrack); + } + } else { + ACTS_WARNING("Fit failed for track " + << itrack << " with error: " << result.error() << ", " + << result.error().message()); + } + } + + std::stringstream ss; + trackStateContainer->statistics().toStream(ss); + ACTS_DEBUG(ss.str()); + + ConstTrackContainer constTracks{ + std::make_shared( + std::move(*trackContainer)), + std::make_shared( + std::move(*trackStateContainer))}; + + m_outputTracks(ctx, std::move(constTracks)); + return ActsExamples::ProcessCode::SUCCESS; +}