diff --git a/core/include/traccc/edm/track_parameters.hpp b/core/include/traccc/edm/track_parameters.hpp index c122600ef..6e4135277 100644 --- a/core/include/traccc/edm/track_parameters.hpp +++ b/core/include/traccc/edm/track_parameters.hpp @@ -48,20 +48,25 @@ inline void wrap_phi(bound_track_parameters& param) { param.set_phi(phi); } -/// Covariance inflation used for track fitting -TRACCC_HOST_DEVICE -inline void inflate_covariance(bound_track_parameters& param, - const traccc::scalar inf_fac) { - auto& cov = param.covariance(); - for (unsigned int i = 0; i < e_bound_size; i++) { - for (unsigned int j = 0; j < e_bound_size; j++) { - if (i == j) { - getter::element(cov, i, i) *= inf_fac; - } else { - getter::element(cov, i, j) = 0.f; +/// Struct for covariance inflation +struct covariance_inflator { + + traccc::scalar inf_fac{1.f}; + + /// Covariance inflation used for track fitting + TRACCC_HOST_DEVICE + inline void operator()(bound_track_parameters& param) { + auto& cov = param.covariance(); + for (unsigned int i = 0; i < e_bound_size; i++) { + for (unsigned int j = 0; j < e_bound_size; j++) { + if (i == j) { + getter::element(cov, i, i) *= inf_fac; + } else { + getter::element(cov, i, j) = 0.f; + } } } } -} +}; } // namespace traccc diff --git a/core/include/traccc/finding/details/find_tracks.hpp b/core/include/traccc/finding/details/find_tracks.hpp index 99ba964c0..4887f30b6 100644 --- a/core/include/traccc/finding/details/find_tracks.hpp +++ b/core/include/traccc/finding/details/find_tracks.hpp @@ -139,6 +139,10 @@ track_candidate_container_types::host find_tracks( // Copy seed to input parameters std::vector in_params(seeds.size()); std::copy(seeds.begin(), seeds.end(), in_params.begin()); + + // Inflate the covariance + covariance_inflator cov_inf(config.covariance_inflation_factor); + std::for_each(in_params.begin(), in_params.end(), cov_inf); std::vector n_trks_per_seed(seeds.size()); std::vector out_params; diff --git a/core/include/traccc/finding/finding_config.hpp b/core/include/traccc/finding/finding_config.hpp index a5fe7bfce..5432bc15d 100644 --- a/core/include/traccc/finding/finding_config.hpp +++ b/core/include/traccc/finding/finding_config.hpp @@ -48,6 +48,9 @@ struct finding_config { /// Particle hypothesis detray::pdg_particle ptc_hypothesis = detray::muon(); + + /// Covariance inflation + traccc::scalar covariance_inflation_factor = 1e3f; }; } // namespace traccc diff --git a/core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp b/core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp index 9dccce732..d11457f92 100644 --- a/core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp +++ b/core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp @@ -153,8 +153,8 @@ class kalman_fitter { : fitter_state.m_fit_actor_state.m_track_states[0] .smoothed(); - inflate_covariance(seed_params_cpy, - m_cfg.covariance_inflation_factor); + covariance_inflator cov_inf(m_cfg.covariance_inflation_factor); + cov_inf(seed_params_cpy); filter(seed_params_cpy, fitter_state); } @@ -234,8 +234,8 @@ class kalman_fitter { typename backward_propagator_type::state propagation( last.smoothed(), m_field, m_detector); - inflate_covariance(propagation._stepping.bound_params(), - m_cfg.covariance_inflation_factor); + covariance_inflator cov_inf(m_cfg.covariance_inflation_factor); + cov_inf(propagation._stepping.bound_params()); propagation._navigation.set_volume( last.smoothed().surface_link().volume()); diff --git a/device/cuda/src/finding/finding_algorithm.cu b/device/cuda/src/finding/finding_algorithm.cu index 96ff6bd76..0183e2de8 100644 --- a/device/cuda/src/finding/finding_algorithm.cu +++ b/device/cuda/src/finding/finding_algorithm.cu @@ -154,6 +154,14 @@ finding_algorithm::operator()( m_copy.setup(in_params_buffer)->ignore(); m_copy(vecmem::get_data(seeds_buffer), vecmem::get_data(in_params_buffer)) ->ignore(); + + // Inflate the covariance + bound_track_parameters_collection_types::device in_params_device( + in_params_buffer); + covariance_inflator cov_inf(m_cfg.covariance_inflation_factor); + thrust::for_each(thrust_policy, in_params_device.begin(), + in_params_device.end(), cov_inf); + vecmem::data::vector_buffer param_liveness_buffer(n_seeds, m_mr.main); m_copy.setup(param_liveness_buffer)->ignore(); diff --git a/device/sycl/src/finding/find_tracks.hpp b/device/sycl/src/finding/find_tracks.hpp index 4955b884a..e68365ecc 100644 --- a/device/sycl/src/finding/find_tracks.hpp +++ b/device/sycl/src/finding/find_tracks.hpp @@ -172,6 +172,17 @@ track_candidate_container_types::buffer find_tracks( mr.main); copy.setup(in_params_buffer)->wait(); copy(seeds, in_params_buffer, vecmem::copy::type::device_to_device)->wait(); + + // Inflate the covariance + bound_track_parameters_collection_types::device in_params_device( + in_params_buffer); + oneapi::dpl::for_each( + policy, in_params_device.begin(), in_params_device.end(), + [config](bound_track_parameters& params) { + covariance_inflator cov_inf(config.covariance_inflation_factor); + cov_inf(params); + }); + vecmem::data::vector_buffer param_liveness_buffer(n_seeds, mr.main); copy.setup(param_liveness_buffer)->wait(); diff --git a/examples/options/src/track_finding.cpp b/examples/options/src/track_finding.cpp index f0a7aff94..b3f919576 100644 --- a/examples/options/src/track_finding.cpp +++ b/examples/options/src/track_finding.cpp @@ -69,6 +69,11 @@ track_finding::track_finding() : interface("Track Finding Options") { m_desc.add_options()("particle-hypothesis", po::value(&m_pdg_number)->default_value(m_pdg_number), "PDG number for the particle hypothesis"); + m_desc.add_options()( + "find-covariance-inflation-factor", + po::value(&m_config.covariance_inflation_factor) + ->default_value(m_config.covariance_inflation_factor), + "Covariance inflation factor for the track finding"); } track_finding::operator finding_config() const { @@ -111,6 +116,9 @@ std::unique_ptr track_finding::as_printable() const { std::to_string(m_config.max_num_skipping_per_cand))); cat->add_child(std::make_unique( "PDG number", std::to_string(m_pdg_number))); + cat->add_child(std::make_unique( + "Covariance inflation factor", + std::to_string(m_config.covariance_inflation_factor))); return cat; }