From 2ed5a9327a3fa317572be3500185f7db6ca13ac6 Mon Sep 17 00:00:00 2001 From: "Alexander J. Pfleger" <70842573+AJPfleger@users.noreply.github.com> Date: Mon, 23 Oct 2023 17:11:22 +0200 Subject: [PATCH] feat: python bindings and truth tracking example for GX2F (#2512) This adds a basic python framework to the Global Chi Square Fitter (GX2F). It runs with the current GX2F-implementation and is created to test the GX2F further. So far some of the pulls already go into the right direction. Note, that we cannot fit with B-Fields != 0. ![Canvas](https://github.com/acts-project/acts/assets/70842573/8a52aa7a-7fec-4137-ac8b-851b35cde6e3) --- .../Algorithms/TrackFitting/CMakeLists.txt | 3 +- .../TrackFitting/TrackFitterFunction.hpp | 17 ++ .../src/GlobalChiSquareFitterFunction.cpp | 162 +++++++++++++++++ .../python/acts/examples/reconstruction.py | 49 ++++++ Examples/Python/src/TrackFitting.cpp | 20 ++- .../Scripts/Python/truth_tracking_gx2f.py | 166 ++++++++++++++++++ 6 files changed, 415 insertions(+), 2 deletions(-) create mode 100644 Examples/Algorithms/TrackFitting/src/GlobalChiSquareFitterFunction.cpp create mode 100644 Examples/Scripts/Python/truth_tracking_gx2f.py diff --git a/Examples/Algorithms/TrackFitting/CMakeLists.txt b/Examples/Algorithms/TrackFitting/CMakeLists.txt index 7b6ebecc175..ca049947e51 100644 --- a/Examples/Algorithms/TrackFitting/CMakeLists.txt +++ b/Examples/Algorithms/TrackFitting/CMakeLists.txt @@ -5,7 +5,8 @@ add_library( src/TrackFittingAlgorithm.cpp src/KalmanFitterFunction.cpp src/RefittingAlgorithm.cpp - src/GsfFitterFunction.cpp) + src/GsfFitterFunction.cpp + src/GlobalChiSquareFitterFunction.cpp) target_include_directories( ActsExamplesTrackFitting PUBLIC $) diff --git a/Examples/Algorithms/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp b/Examples/Algorithms/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp index de0dcf09650..cfbb31c24e4 100644 --- a/Examples/Algorithms/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp +++ b/Examples/Algorithms/TrackFitting/include/ActsExamples/TrackFitting/TrackFitterFunction.hpp @@ -92,4 +92,21 @@ std::shared_ptr makeGsfFitterFunction( bool abortOnError, bool disableAllMaterialHandling, const Acts::Logger& logger); +/// Makes a fitter function object for the Global Chi Square Fitter (GX2F) +/// +/// @param trackingGeometry the trackingGeometry for the propagator +/// @param magneticField the magnetic field for the propagator +/// @param multipleScattering bool +/// @param energyLoss bool +/// @param freeToBoundCorrection bool +/// @param logger a logger instance +std::shared_ptr makeGlobalChiSquareFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering = true, bool energyLoss = true, + Acts::FreeToBoundCorrection freeToBoundCorrection = + Acts::FreeToBoundCorrection(), + const Acts::Logger& logger = *Acts::getDefaultLogger("Gx2f", + Acts::Logging::INFO)); + } // namespace ActsExamples diff --git a/Examples/Algorithms/TrackFitting/src/GlobalChiSquareFitterFunction.cpp b/Examples/Algorithms/TrackFitting/src/GlobalChiSquareFitterFunction.cpp new file mode 100644 index 00000000000..a641f93de0e --- /dev/null +++ b/Examples/Algorithms/TrackFitting/src/GlobalChiSquareFitterFunction.cpp @@ -0,0 +1,162 @@ +// This file is part of the Acts project. +// +// Copyright (C) 2023 CERN for the benefit of the Acts project +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this +// file, You can obtain one at http://mozilla.org/MPL/2.0/. + +// TODO We still use some Kalman Fitter functionalities. Check for replacement + +#include "Acts/Definitions/Direction.hpp" +#include "Acts/Definitions/TrackParametrization.hpp" +#include "Acts/EventData/MultiTrajectory.hpp" +#include "Acts/EventData/TrackContainer.hpp" +#include "Acts/EventData/TrackStatePropMask.hpp" +#include "Acts/EventData/VectorMultiTrajectory.hpp" +#include "Acts/EventData/VectorTrackContainer.hpp" +#include "Acts/EventData/detail/CorrectedTransformationFreeToBound.hpp" +#include "Acts/Geometry/GeometryIdentifier.hpp" +#include "Acts/Propagator/DirectNavigator.hpp" +#include "Acts/Propagator/EigenStepper.hpp" +#include "Acts/Propagator/Navigator.hpp" +#include "Acts/Propagator/Propagator.hpp" +#include "Acts/TrackFitting/GlobalChiSquareFitter.hpp" +#include "Acts/TrackFitting/KalmanFitter.hpp" +#include "Acts/Utilities/Delegate.hpp" +#include "Acts/Utilities/Logger.hpp" +#include "ActsExamples/EventData/IndexSourceLink.hpp" +#include "ActsExamples/EventData/MeasurementCalibration.hpp" +#include "ActsExamples/EventData/Track.hpp" +#include "ActsExamples/TrackFitting/RefittingCalibrator.hpp" +#include "ActsExamples/TrackFitting/TrackFitterFunction.hpp" + +#include +#include +#include +#include +#include +#include + +namespace Acts { +class MagneticFieldProvider; +class SourceLink; +class Surface; +class TrackingGeometry; +} // namespace Acts + +namespace { + +using Stepper = Acts::EigenStepper<>; +using Propagator = Acts::Propagator; +using Fitter = + Acts::Experimental::Gx2Fitter; +using DirectPropagator = Acts::Propagator; +using DirectFitter = + Acts::KalmanFitter; + +using TrackContainer = + Acts::TrackContainer; + +using namespace ActsExamples; + +struct GlobalChiSquareFitterFunctionImpl final : public TrackFitterFunction { + Fitter fitter; + DirectFitter directFitter; + + bool multipleScattering = false; + bool energyLoss = false; + Acts::FreeToBoundCorrection freeToBoundCorrection; + + IndexSourceLink::SurfaceAccessor m_slSurfaceAccessor; + + GlobalChiSquareFitterFunctionImpl(Fitter&& f, DirectFitter&& df, + const Acts::TrackingGeometry& trkGeo) + : fitter(std::move(f)), + directFitter(std::move(df)), + m_slSurfaceAccessor{trkGeo} {} + + template + auto makeGx2fOptions(const GeneralFitterOptions& options, + const calibrator_t& calibrator) const { + Acts::Experimental::Gx2FitterExtensions + extensions; + extensions.calibrator.connect<&calibrator_t::calibrate>(&calibrator); + + extensions.surfaceAccessor + .connect<&IndexSourceLink::SurfaceAccessor::operator()>( + &m_slSurfaceAccessor); + + const Acts::Experimental::Gx2FitterOptions gx2fOptions( + options.geoContext, options.magFieldContext, options.calibrationContext, + extensions, options.propOptions, &(*options.referenceSurface), + multipleScattering, energyLoss, freeToBoundCorrection, 5); + + return gx2fOptions; + } + + TrackFitterResult operator()(const std::vector& sourceLinks, + const TrackParameters& initialParameters, + const GeneralFitterOptions& options, + const MeasurementCalibratorAdapter& calibrator, + TrackContainer& tracks) const override { + const auto gx2fOptions = makeGx2fOptions(options, calibrator); + return fitter.fit(sourceLinks.begin(), sourceLinks.end(), initialParameters, + gx2fOptions, tracks); + } + + // We need a placeholder for the directNavigator overload. Otherwise, we would + // have an unimplemented pure virtual method in a final class. + TrackFitterResult operator()( + const std::vector& /*sourceLinks*/, + const TrackParameters& /*initialParameters*/, + const GeneralFitterOptions& /*options*/, + const RefittingCalibrator& /*calibrator*/, + const std::vector& /*surfaceSequence*/, + TrackContainer& /*tracks*/) const override { + throw std::runtime_error( + "direct navigation with GX2 fitter is not implemented"); + } +}; + +} // namespace + +std::shared_ptr +ActsExamples::makeGlobalChiSquareFitterFunction( + std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering, bool energyLoss, + Acts::FreeToBoundCorrection freeToBoundCorrection, + const Acts::Logger& logger) { + // Stepper should be copied into the fitters + const Stepper stepper(std::move(magneticField)); + + // Standard fitter + const auto& geo = *trackingGeometry; + Acts::Navigator::Config cfg{std::move(trackingGeometry)}; + cfg.resolvePassive = false; + cfg.resolveMaterial = true; + cfg.resolveSensitive = true; + Acts::Navigator navigator(cfg, logger.cloneWithSuffix("Navigator")); + Propagator propagator(stepper, std::move(navigator), + logger.cloneWithSuffix("Propagator")); + Fitter trackFitter(std::move(propagator), logger.cloneWithSuffix("Fitter")); + + // Direct fitter + Acts::DirectNavigator directNavigator{ + logger.cloneWithSuffix("DirectNavigator")}; + DirectPropagator directPropagator(stepper, std::move(directNavigator), + logger.cloneWithSuffix("DirectPropagator")); + DirectFitter directTrackFitter(std::move(directPropagator), + logger.cloneWithSuffix("DirectFitter")); + + // build the fitter function. owns the fitter object. + auto fitterFunction = std::make_shared( + std::move(trackFitter), std::move(directTrackFitter), geo); + fitterFunction->multipleScattering = multipleScattering; + fitterFunction->energyLoss = energyLoss; + fitterFunction->freeToBoundCorrection = freeToBoundCorrection; + + return fitterFunction; +} diff --git a/Examples/Python/python/acts/examples/reconstruction.py b/Examples/Python/python/acts/examples/reconstruction.py index 42dfa4f4ab6..27e2f3731d5 100644 --- a/Examples/Python/python/acts/examples/reconstruction.py +++ b/Examples/Python/python/acts/examples/reconstruction.py @@ -1134,6 +1134,55 @@ def addCKFTracks( return s +def addGx2fTracks( + s: acts.examples.Sequencer, + trackingGeometry: acts.TrackingGeometry, + field: acts.MagneticFieldProvider, + # directNavigation: bool = False, + inputProtoTracks: str = "truth_particle_tracks", + multipleScattering: bool = False, + energyLoss: bool = False, + clusters: str = None, + calibrator: acts.examples.MeasurementCalibrator = acts.examples.makePassThroughCalibrator(), + logLevel: Optional[acts.logging.Level] = None, +) -> None: + customLogLevel = acts.examples.defaultLogging(s, logLevel) + + gx2fOptions = { + "multipleScattering": multipleScattering, + "energyLoss": energyLoss, + "freeToBoundCorrection": acts.examples.FreeToBoundCorrection(False), + "level": customLogLevel(), + } + + fitAlg = acts.examples.TrackFittingAlgorithm( + level=customLogLevel(), + inputMeasurements="measurements", + inputSourceLinks="sourcelinks", + inputProtoTracks=inputProtoTracks, + inputInitialTrackParameters="estimatedparameters", + inputClusters=clusters if clusters is not None else "", + outputTracks="gx2fTracks", + pickTrack=-1, + fit=acts.examples.makeGlobalChiSquareFitterFunction( + trackingGeometry, field, **gx2fOptions + ), + calibrator=calibrator, + ) + s.addAlgorithm(fitAlg) + s.addWhiteboardAlias("tracks", fitAlg.config.outputTracks) + + trackConverter = acts.examples.TracksToTrajectories( + level=customLogLevel(), + inputTracks=fitAlg.config.outputTracks, + outputTrajectories="gx2fTrajectories", + ) + s.addAlgorithm(trackConverter) + s.addWhiteboardAlias("trajectories", trackConverter.config.outputTrajectories) + + return s + + def addTrajectoryWriters( s: acts.examples.Sequencer, name: str, diff --git a/Examples/Python/src/TrackFitting.cpp b/Examples/Python/src/TrackFitting.cpp index 8dcf9f0e917..6451cba4d45 100644 --- a/Examples/Python/src/TrackFitting.cpp +++ b/Examples/Python/src/TrackFitting.cpp @@ -1,6 +1,6 @@ // This file is part of the Acts project. // -// Copyright (C) 2021 CERN for the benefit of the Acts project +// Copyright (C) 2021-2023 CERN for the benefit of the Acts project // // This Source Code Form is subject to the terms of the Mozilla Public // License, v. 2.0. If a copy of the MPL was not distributed with this @@ -127,6 +127,24 @@ void addTrackFitting(Context& ctx) { py::arg("weightCutoff"), py::arg("finalReductionMethod"), py::arg("abortOnError"), py::arg("disableAllMaterialHandling"), py::arg("level")); + + mex.def( + "makeGlobalChiSquareFitterFunction", + [](std::shared_ptr trackingGeometry, + std::shared_ptr magneticField, + bool multipleScattering, bool energyLoss, + Acts::FreeToBoundCorrection freeToBoundCorrection, + Logging::Level level) { + return ActsExamples::makeGlobalChiSquareFitterFunction( + trackingGeometry, magneticField, multipleScattering, energyLoss, + freeToBoundCorrection, *Acts::getDefaultLogger("Gx2f", level)); + }, + py::arg("trackingGeometry"), py::arg("magneticField"), + py::arg("multipleScattering"), py::arg("energyLoss"), + py::arg("freeToBoundCorrection"), py::arg("level")); + + // TODO add other important parameters like nUpdates + // TODO add also in trackfitterfunction } { diff --git a/Examples/Scripts/Python/truth_tracking_gx2f.py b/Examples/Scripts/Python/truth_tracking_gx2f.py new file mode 100644 index 00000000000..910fc05718a --- /dev/null +++ b/Examples/Scripts/Python/truth_tracking_gx2f.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 + +from pathlib import Path +from typing import Optional + +import acts +import acts.examples + +u = acts.UnitConstants + + +def runTruthTrackingGx2f( + trackingGeometry: acts.TrackingGeometry, + field: acts.MagneticFieldProvider, + outputDir: Path, + digiConfigFile: Path, + s: acts.examples.Sequencer = None, + inputParticlePath: Optional[Path] = None, +): + from acts.examples.simulation import ( + addParticleGun, + MomentumConfig, + EtaConfig, + ParticleConfig, + addFatras, + addDigitization, + ) + from acts.examples.reconstruction import ( + addSeeding, + SeedingAlgorithm, + TruthSeedRanges, + addGx2fTracks, + ) + + s = s or acts.examples.Sequencer( + events=10000, numThreads=-1, logLevel=acts.logging.INFO + ) + + rnd = acts.examples.RandomNumbers() + outputDir = Path(outputDir) + + if inputParticlePath is None: + addParticleGun( + s, + MomentumConfig(100.0 * u.GeV, 100.0 * u.GeV, transverse=True), + EtaConfig(-2.0, 2.0), + ParticleConfig(2, acts.PdgParticle.eMuon, False), + multiplicity=1, + rnd=rnd, + outputDirRoot=outputDir, + ) + else: + acts.logging.getLogger("Truth tracking example").info( + "Reading particles from %s", inputParticlePath.resolve() + ) + assert inputParticlePath.exists() + s.addReader( + RootParticleReader( + level=acts.logging.INFO, + filePath=str(inputParticlePath.resolve()), + particleCollection="particles_input", + orderedEvents=False, + ) + ) + + addFatras( + s, + trackingGeometry, + field, + rnd=rnd, + enableInteractions=True, + ) + + addDigitization( + s, + trackingGeometry, + field, + digiConfigFile=digiConfigFile, + rnd=rnd, + ) + + addSeeding( + s, + trackingGeometry, + field, + seedingAlgorithm=SeedingAlgorithm.TruthSmeared, + rnd=rnd, + truthSeedRanges=TruthSeedRanges( + pt=(1 * u.GeV, None), + nHits=(9, None), + ), + ) + + addGx2fTracks( + s, + trackingGeometry, + field, + # directNavigation, + ) + + # Output + s.addWriter( + acts.examples.RootTrajectoryStatesWriter( + level=acts.logging.INFO, + inputTrajectories="trajectories", + inputParticles="truth_seeds_selected", + inputSimHits="simhits", + inputMeasurementParticlesMap="measurement_particles_map", + inputMeasurementSimHitsMap="measurement_simhits_map", + filePath=str(outputDir / "trackstates_fitter.root"), + ) + ) + + s.addWriter( + acts.examples.RootTrajectorySummaryWriter( + level=acts.logging.INFO, + inputTrajectories="trajectories", + inputParticles="truth_seeds_selected", + inputMeasurementParticlesMap="measurement_particles_map", + filePath=str(outputDir / "tracksummary_fitter.root"), + ) + ) + + # TODO: PerformanceWriters are not tested yet + # s.addWriter( + # acts.examples.TrackFinderPerformanceWriter( + # level=acts.logging.INFO, + # inputProtoTracks="truth_particle_tracks", + # inputParticles="truth_seeds_selected", + # inputMeasurementParticlesMap="measurement_particles_map", + # filePath=str(outputDir / "performance_track_finder.root"), + # ) + # ) + # + # s.addWriter( + # acts.examples.TrackFitterPerformanceWriter( + # level=acts.logging.INFO, + # inputTrajectories="trajectories", + # inputParticles="truth_seeds_selected", + # inputMeasurementParticlesMap="measurement_particles_map", + # filePath=str(outputDir / "performance_track_fitter.root"), + # ) + # ) + + return s + + +if "__main__" == __name__: + srcdir = Path(__file__).resolve().parent.parent.parent.parent + + # detector, trackingGeometry, _ = getOpenDataDetector() + detector, trackingGeometry, decorators = acts.examples.GenericDetector.create() + + field = acts.ConstantBField(acts.Vector3(0, 0, 2 * u.T)) + + runTruthTrackingGx2f( + trackingGeometry=trackingGeometry, + # decorators=decorators, + field=field, + digiConfigFile=srcdir + / "Examples/Algorithms/Digitization/share/default-smearing-config-generic.json", + # "thirdparty/OpenDataDetector/config/odd-digi-smearing-config.json", + # outputCsv=True, + # inputParticlePath=inputParticlePath, + outputDir=Path.cwd(), + ).run()