Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Undo some changes from #844 #881

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions benchmarks/common/benchmarks/toy_detector_benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@

// Detray include(s).
#include <detray/detectors/bfield.hpp>
#include <detray/io/frontend/detector_reader.hpp>
#include <detray/io/frontend/detector_writer.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/propagator.hpp>
#include <detray/propagator/rk_stepper.hpp>
#include <detray/test/utils/detectors/build_toy_detector.hpp>
#include <detray/test/utils/simulation/event_generator/track_generators.hpp>
#include <detray/tracks/ray.hpp>
Expand Down
8 changes: 3 additions & 5 deletions benchmarks/cpu/toy_detector_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,16 @@
#include "benchmarks/toy_detector_benchmark.hpp"

// Detray include(s).
#include <detray/core/detector.hpp>
#include <detray/detectors/bfield.hpp>
#include <detray/io/frontend/detector_reader.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/propagator.hpp>
#include <detray/propagator/rk_stepper.hpp>

// VecMem include(s).
#include <vecmem/memory/host_memory_resource.hpp>

// Google benchmark include(s).
#include <benchmark/benchmark.h>

BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {

// Type declarations
using host_detector_type = traccc::default_detector::host;
Expand Down Expand Up @@ -88,3 +84,5 @@ BENCHMARK_F(ToyDetectorBenchmark, CPU)(benchmark::State& state) {
state.counters["event_throughput_Hz"] = benchmark::Counter(
static_cast<double>(n_events), benchmark::Counter::kIsRate);
}

BENCHMARK_REGISTER_F(ToyDetectorBenchmark, CPU)->UseRealTime();
6 changes: 3 additions & 3 deletions benchmarks/cuda/toy_detector_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,9 @@
#include "benchmarks/toy_detector_benchmark.hpp"

// Detray include(s).
#include <detray/core/detector.hpp>
#include <detray/detectors/bfield.hpp>
#include <detray/io/frontend/detector_reader.hpp>
#include <detray/navigation/navigator.hpp>
#include <detray/propagator/propagator.hpp>
#include <detray/propagator/rk_stepper.hpp>

// VecMem include(s).
Expand All @@ -39,7 +37,7 @@
// Google benchmark include(s).
#include <benchmark/benchmark.h>

BENCHMARK_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
BENCHMARK_DEFINE_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {

// Type declarations
using rk_stepper_type = detray::rk_stepper<
Expand Down Expand Up @@ -163,3 +161,5 @@ BENCHMARK_F(ToyDetectorBenchmark, CUDA)(benchmark::State& state) {
state.counters["event_throughput_Hz"] = benchmark::Counter(
static_cast<double>(n_events), benchmark::Counter::kIsRate);
}

BENCHMARK_REGISTER_F(ToyDetectorBenchmark, CUDA)->UseRealTime();
6 changes: 4 additions & 2 deletions core/include/traccc/finding/details/find_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "traccc/finding/candidate_link.hpp"
#include "traccc/finding/finding_config.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
#include "traccc/fitting/status_codes.hpp"
#include "traccc/sanity/contiguous_on.hpp"
#include "traccc/utils/particle.hpp"
#include "traccc/utils/projections.hpp"
Expand Down Expand Up @@ -236,14 +237,15 @@ track_candidate_container_types::host find_tracks(
track_state<algebra_type> trk_state(meas);

// Run the Kalman update on a copy of the track parameters
const bool res =
const kalman_fitter_status res =
sf.template visit_mask<gain_matrix_updater<algebra_type>>(
trk_state, in_param);

const traccc::scalar chi2 = trk_state.filtered_chi2();

// The chi2 from Kalman update should be less than chi2_max
if (res && chi2 < config.chi2_max) {
if (res == kalman_fitter_status::SUCCESS &&
chi2 < config.chi2_max) {
n_branches++;

links[step].push_back({{previous_step, in_param_id},
Expand Down
14 changes: 10 additions & 4 deletions core/include/traccc/fitting/details/fit_tracks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
// Project include(s).
#include "traccc/edm/track_candidate.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/status_codes.hpp"

// VecMem include(s).
#include <vecmem/memory/memory_resource.hpp>
Expand Down Expand Up @@ -61,11 +62,16 @@ track_state_container_types::host fit_tracks(
typename fitter_t::state fitter_state(vecmem::get_data(input_states));

// Run the fitter.
fitter.fit(track_candidates.get_headers()[i].seed_params, fitter_state);
kalman_fitter_status fit_status = fitter.fit(
track_candidates.get_headers()[i].seed_params, fitter_state);

// Save the results into the output container.
result.push_back(std::move(fitter_state.m_fit_res),
std::move(input_states));
if (fit_status == kalman_fitter_status::SUCCESS) {
// Save the results into the output container.
result.push_back(std::move(fitter_state.m_fit_res),
std::move(input_states));
} else {
// TODO: Print a warning here.
}
}

// Return the fitted track states.
Expand Down
28 changes: 23 additions & 5 deletions core/include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/edm/track_parameters.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/status_codes.hpp"

// Detray inlcude(s)
#include <detray/geometry/shapes/line.hpp>
Expand Down Expand Up @@ -43,7 +44,7 @@ struct gain_matrix_smoother {
///
/// @return true if the update succeeds
template <typename mask_group_t, typename index_t>
TRACCC_HOST_DEVICE inline void operator()(
TRACCC_HOST_DEVICE [[nodiscard]] inline kalman_fitter_status operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
track_state<algebra_t>& cur_state,
const track_state<algebra_t>& next_state) {
Expand All @@ -53,14 +54,16 @@ struct gain_matrix_smoother {
const auto D = cur_state.get_measurement().meas_dim;
assert(D == 1u || D == 2u);
if (D == 1u) {
smoothe<1u, shape_type>(cur_state, next_state);
return smoothe<1u, shape_type>(cur_state, next_state);
} else if (D == 2u) {
smoothe<2u, shape_type>(cur_state, next_state);
return smoothe<2u, shape_type>(cur_state, next_state);
}

return kalman_fitter_status::ERROR_OTHER;
}

template <size_type D, typename shape_t>
TRACCC_HOST_DEVICE inline void smoothe(
TRACCC_HOST_DEVICE [[nodiscard]] inline kalman_fitter_status smoothe(
track_state<algebra_t>& cur_state,
const track_state<algebra_t>& next_state) const {
const auto meas = cur_state.get_measurement();
Expand Down Expand Up @@ -105,6 +108,21 @@ struct gain_matrix_smoother {
cur_state.smoothed().set_vector(smt_vec);
cur_state.smoothed().set_covariance(smt_cov);

// Return false if track is parallel to z-axis or phi is not finite
const scalar theta = cur_state.smoothed().theta();

if (theta <= 0.f || theta >= constant<traccc::scalar>::pi) {
return kalman_fitter_status::ERROR_THETA_ZERO;
}

if (!std::isfinite(cur_state.smoothed().phi())) {
return kalman_fitter_status::ERROR_INVERSION;
}

if (std::abs(cur_state.smoothed().qop()) == 0.f) {
return kalman_fitter_status::ERROR_QOP_ZERO;
}

// Wrap the phi in the range of [-pi, pi]
wrap_phi(cur_state.smoothed());

Expand All @@ -131,7 +149,7 @@ struct gain_matrix_smoother {

cur_state.smoothed_chi2() = getter::element(chi2, 0, 0);

return;
return kalman_fitter_status::SUCCESS;
}
};

Expand Down
23 changes: 16 additions & 7 deletions core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/definitions/track_parametrization.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/status_codes.hpp"

// Detray inlcude(s)
#include <detray/geometry/shapes/line.hpp>
Expand Down Expand Up @@ -40,7 +41,7 @@ struct gain_matrix_updater {
///
/// @return true if the update succeeds
template <typename mask_group_t, typename index_t>
TRACCC_HOST_DEVICE inline bool operator()(
TRACCC_HOST_DEVICE [[nodiscard]] inline kalman_fitter_status operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
track_state<algebra_t>& trk_state,
const bound_track_parameters& bound_params) const {
Expand All @@ -49,7 +50,7 @@ struct gain_matrix_updater {

const auto D = trk_state.get_measurement().meas_dim;
assert(D == 1u || D == 2u);
bool result = false;
kalman_fitter_status result = kalman_fitter_status::ERROR_OTHER;
switch (D) {
case 1u:
result = update<1u, shape_type>(trk_state, bound_params);
Expand All @@ -64,7 +65,7 @@ struct gain_matrix_updater {
}

template <size_type D, typename shape_t>
TRACCC_HOST_DEVICE inline bool update(
TRACCC_HOST_DEVICE [[nodiscard]] inline kalman_fitter_status update(
track_state<algebra_t>& trk_state,
const bound_track_parameters& bound_params) const {

Expand Down Expand Up @@ -128,9 +129,17 @@ struct gain_matrix_updater {

// Return false if track is parallel to z-axis or phi is not finite
const scalar theta = bound_params.theta();
if (theta <= 0.f || theta >= constant<traccc::scalar>::pi ||
!std::isfinite(bound_params.phi())) {
return false;

if (theta <= 0.f || theta >= constant<traccc::scalar>::pi) {
return kalman_fitter_status::ERROR_THETA_ZERO;
}

if (!std::isfinite(bound_params.phi())) {
return kalman_fitter_status::ERROR_INVERSION;
}

if (std::abs(bound_params.qop()) == 0.f) {
return kalman_fitter_status::ERROR_QOP_ZERO;
}

// Set the track state parameters
Expand All @@ -141,7 +150,7 @@ struct gain_matrix_updater {
// Wrap the phi in the range of [-pi, pi]
wrap_phi(trk_state.filtered());

return true;
return kalman_fitter_status::SUCCESS;
}
};

Expand Down
5 changes: 3 additions & 2 deletions core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/fitting/status_codes.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand Down Expand Up @@ -134,7 +135,7 @@ struct kalman_actor : detray::actor {
// Run Kalman Gain Updater
const auto sf = navigation.get_surface();

bool res = false;
kalman_fitter_status res = kalman_fitter_status::SUCCESS;

if (!actor_state.backward_mode) {
// Forward filter
Expand All @@ -153,7 +154,7 @@ struct kalman_actor : detray::actor {
}

// Abort if the Kalman update fails
if (!res) {
if (res != kalman_fitter_status::SUCCESS) {
propagation._heartbeat &= navigation.abort();
return;
}
Expand Down
44 changes: 35 additions & 9 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
#include "traccc/fitting/kalman_filter/statistics_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/fitting/status_codes.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand Down Expand Up @@ -138,8 +139,8 @@ class kalman_fitter {
/// @param seed_params seed track parameter
/// @param fitter_state the state of kalman fitter
template <typename seed_parameters_t>
TRACCC_HOST_DEVICE void fit(const seed_parameters_t& seed_params,
state& fitter_state) {
TRACCC_HOST_DEVICE [[nodiscard]] kalman_fitter_status fit(
const seed_parameters_t& seed_params, state& fitter_state) {

// Run the kalman filtering for a given number of iterations
for (std::size_t i = 0; i < m_cfg.n_iterations; i++) {
Expand All @@ -155,8 +156,14 @@ class kalman_fitter {
inflate_covariance(seed_params_cpy,
m_cfg.covariance_inflation_factor);

filter(seed_params_cpy, fitter_state);
if (kalman_fitter_status res =
filter(seed_params_cpy, fitter_state);
res != kalman_fitter_status::SUCCESS) {
return res;
}
}

return kalman_fitter_status::SUCCESS;
}

/// Run the kalman fitter for an iteration
Expand All @@ -166,8 +173,8 @@ class kalman_fitter {
/// @param seed_params seed track parameter
/// @param fitter_state the state of kalman fitter
template <typename seed_parameters_t>
TRACCC_HOST_DEVICE void filter(const seed_parameters_t& seed_params,
state& fitter_state) {
TRACCC_HOST_DEVICE [[nodiscard]] kalman_fitter_status filter(
const seed_parameters_t& seed_params, state& fitter_state) {

// Create propagator
propagator_type propagator(m_cfg.propagation);
Expand Down Expand Up @@ -198,10 +205,15 @@ class kalman_fitter {
propagator.propagate(propagation, fitter_state());

// Run smoothing
smooth(fitter_state);
if (kalman_fitter_status res = smooth(fitter_state);
res != kalman_fitter_status::SUCCESS) {
return res;
}

// Update track fitting qualities
update_statistics(fitter_state);

return kalman_fitter_status::SUCCESS;
}

/// Run smoothing after kalman filtering
Expand All @@ -210,7 +222,8 @@ class kalman_fitter {
/// track and vertex fitting", R.Frühwirth, NIM A.
///
/// @param fitter_state the state of kalman fitter
TRACCC_HOST_DEVICE void smooth(state& fitter_state) {
TRACCC_HOST_DEVICE [[nodiscard]] kalman_fitter_status smooth(
state& fitter_state) {

auto& track_states = fitter_state.m_fit_actor_state.m_track_states;

Expand Down Expand Up @@ -251,6 +264,13 @@ class kalman_fitter {
detray::navigation::direction::e_backward);
fitter_state.m_fit_actor_state.backward_mode = true;

const auto& dir = propagation._stepping().dir();
if (dir[0] == 0.f && dir[1] == 0.f) {
// Particle is exactly parallel to the beampipe, which we
// cannot represent.
return kalman_fitter_status::ERROR_THETA_ZERO;
}

propagator.propagate(propagation,
fitter_state.backward_actor_state());

Expand All @@ -266,10 +286,16 @@ class kalman_fitter {

const detray::tracking_surface sf{m_detector,
it->surface_link()};
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
*it, *(it - 1));
if (kalman_fitter_status res =
sf.template visit_mask<
gain_matrix_smoother<algebra_type>>(*it, *(it - 1));
res != kalman_fitter_status::SUCCESS) {
return res;
}
}
}

return kalman_fitter_status::SUCCESS;
}

TRACCC_HOST_DEVICE
Expand Down
Loading
Loading