Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Backward filter for Two-filter Smoothing #788

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ traccc_add_library( traccc_core core TYPE SHARED
"include/traccc/fitting/kalman_filter/kalman_fitter.hpp"
"include/traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
"include/traccc/fitting/kalman_filter/statistics_updater.hpp"
"include/traccc/fitting/kalman_filter/two_filters_smoother.hpp"
"include/traccc/fitting/details/fit_tracks.hpp"
"include/traccc/fitting/kalman_fitting_algorithm.hpp"
"src/fitting/kalman_fitting_algorithm.cpp"
Expand Down
16 changes: 16 additions & 0 deletions core/include/traccc/edm/track_parameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,20 @@ inline void wrap_phi(bound_track_parameters& param) {
param.set_phi(phi);
}

/// Covariance inflation used for track fitting
TRACCC_HOST_DEVICE
inline void inflate_covariance(bound_track_parameters& param,
const traccc::scalar inf_fac) {
auto& cov = param.covariance();
for (unsigned int i = 0; i < e_bound_size; i++) {
for (unsigned int j = 0; j < e_bound_size; j++) {
if (i == j) {
getter::element(cov, i, i) *= inf_fac;
} else {
getter::element(cov, i, j) = 0.f;
}
}
}
}

} // namespace traccc
17 changes: 17 additions & 0 deletions core/include/traccc/edm/track_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ struct fitting_result {
// The number of holes (The number of sensitive surfaces which do not have a
// measurement for the track pattern)
unsigned int n_holes{0u};

/// Reset the statistics
TRACCC_HOST_DEVICE
void reset_statistics() {
ndf = 0.f;
chi2 = 0.f;
n_holes = 0u;
}
};

/// Fitting result per measurement
Expand Down Expand Up @@ -160,6 +168,14 @@ struct track_state {
TRACCC_HOST_DEVICE
inline const scalar_type& filtered_chi2() const { return m_filtered_chi2; }

/// @return the non-const chi square of backward filter
TRACCC_HOST_DEVICE
inline scalar_type& backward_chi2() { return m_backward_chi2; }

/// @return the const chi square of backward filter
TRACCC_HOST_DEVICE
inline scalar_type backward_chi2() const { return m_backward_chi2; }

/// @return the non-const filtered parameter
TRACCC_HOST_DEVICE
inline bound_track_parameters_type& filtered() { return m_filtered; }
Expand Down Expand Up @@ -200,6 +216,7 @@ struct track_state {
bound_track_parameters_type m_filtered;
scalar_type m_smoothed_chi2 = 0.f;
bound_track_parameters_type m_smoothed;
scalar_type m_backward_chi2 = 0.f;
};

/// Declare all track_state collection types
Expand Down
4 changes: 4 additions & 0 deletions core/include/traccc/fitting/fitting_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ struct fitting_config {
/// Particle hypothesis
detray::pdg_particle<traccc::scalar> ptc_hypothesis =
detray::muon<traccc::scalar>();

/// Smoothing with backward filter
bool use_backward_filter = false;
traccc::scalar covariance_inflation_factor = 1e3f;
};

} // namespace traccc
68 changes: 53 additions & 15 deletions core/include/traccc/fitting/kalman_filter/kalman_actor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "traccc/definitions/qualifiers.hpp"
#include "traccc/edm/track_state.hpp"
#include "traccc/fitting/kalman_filter/gain_matrix_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand All @@ -33,31 +34,50 @@ struct kalman_actor : detray::actor {
state(vector_t<track_state_type>&& track_states)
: m_track_states(std::move(track_states)) {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// Constructor with the vector of track states
TRACCC_HOST_DEVICE
state(const vector_t<track_state_type>& track_states)
: m_track_states(track_states) {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// @return the reference of track state pointed by the iterator
TRACCC_HOST_DEVICE
track_state_type& operator()() { return *m_it; }
track_state_type& operator()() {
if (!backward_mode) {
return *m_it;
} else {
return *m_it_rev;
}
}

/// Reset the iterator
TRACCC_HOST_DEVICE
void reset() { m_it = m_track_states.begin(); }
void reset() {
m_it = m_track_states.begin();
m_it_rev = m_track_states.rbegin();
}

/// Advance the iterator
TRACCC_HOST_DEVICE
void next() { m_it++; }
void next() {
if (!backward_mode) {
m_it++;
} else {
m_it_rev++;
}
}

/// @return true if the iterator reaches the end of vector
TRACCC_HOST_DEVICE
bool is_complete() const {
if (m_it == m_track_states.end()) {
bool is_complete() {
if (!backward_mode && m_it == m_track_states.end()) {
return true;
} else if (backward_mode && m_it_rev == m_track_states.rend()) {
return true;
}
return false;
Expand All @@ -69,9 +89,15 @@ struct kalman_actor : detray::actor {
// iterator for forward filtering
typename vector_t<track_state_type>::iterator m_it;

// iterator for backward filtering
typename vector_t<track_state_type>::reverse_iterator m_it_rev;

// The number of holes (The number of sensitive surfaces which do not
// have a measurement for the track pattern)
unsigned int n_holes{0u};

// Run back filtering for smoothing, if true
bool backward_mode = false;
};

/// Actor operation to perform the Kalman filtering
Expand Down Expand Up @@ -99,32 +125,44 @@ struct kalman_actor : detray::actor {
// Increase the hole counts if the propagator fails to find the next
// measurement
if (navigation.barcode() != trk_state.surface_link()) {
actor_state.n_holes++;
if (!actor_state.backward_mode) {
actor_state.n_holes++;
}
return;
}

// This track state is not a hole
trk_state.is_hole = false;
if (!actor_state.backward_mode) {
trk_state.is_hole = false;
}

// Run Kalman Gain Updater
const auto sf = navigation.get_surface();

const bool res =
sf.template visit_mask<gain_matrix_updater<algebra_t>>(
bool res = false;

if (!actor_state.backward_mode) {
// Forward filter
res = sf.template visit_mask<gain_matrix_updater<algebra_t>>(
trk_state, propagation._stepping.bound_params());

// Update the propagation flow
stepping.bound_params() = trk_state.filtered();

// Set full jacobian
trk_state.jacobian() = stepping.full_jacobian();
} else {
// Backward filter for smoothing
res = sf.template visit_mask<two_filters_smoother<algebra_t>>(
trk_state, propagation._stepping.bound_params());
}

// Abort if the Kalman update fails
if (!res) {
propagation._heartbeat &= navigation.abort();
return;
}

// Update the propagation flow
stepping.bound_params() = trk_state.filtered();

// Set full jacobian
trk_state.jacobian() = stepping.full_jacobian();

// Change the charge of hypothesized particles when the sign of qop
// is changed (This rarely happens when qop is set with a poor seed
// resolution)
Expand Down
94 changes: 69 additions & 25 deletions core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "traccc/fitting/kalman_filter/kalman_actor.hpp"
#include "traccc/fitting/kalman_filter/kalman_step_aborter.hpp"
#include "traccc/fitting/kalman_filter/statistics_updater.hpp"
#include "traccc/fitting/kalman_filter/two_filters_smoother.hpp"
#include "traccc/utils/particle.hpp"

// detray include(s).
Expand Down Expand Up @@ -67,10 +68,17 @@ class kalman_fitter {
detray::actor_chain<detray::dtuple, aborter, transporter, interactor,
fit_actor, resetter, kalman_step_aborter>;

using backward_actor_chain_type =
detray::actor_chain<detray::dtuple, aborter, transporter, fit_actor,
interactor, resetter, kalman_step_aborter>;

// Propagator type
using propagator_type =
detray::propagator<stepper_t, navigator_t, actor_chain_type>;

using backward_propagator_type =
detray::propagator<stepper_t, navigator_t, backward_actor_chain_type>;

/// Constructor with a detector
///
/// @param det the detector object
Expand Down Expand Up @@ -104,6 +112,14 @@ class kalman_fitter {
m_resetter_state, m_step_aborter_state);
}

/// @return the actor chain state
TRACCC_HOST_DEVICE
typename backward_actor_chain_type::state backward_actor_state() {
return detray::tie(m_aborter_state, m_transporter_state,
m_fit_actor_state, m_interactor_state,
m_resetter_state, m_step_aborter_state);
}

/// Individual actor states
typename aborter::state m_aborter_state{};
typename transporter::state m_transporter_state{};
Expand Down Expand Up @@ -132,17 +148,15 @@ class kalman_fitter {
// Reset the iterator of kalman actor
fitter_state.m_fit_actor_state.reset();

if (i == 0) {
filter(seed_params, fitter_state);
}
// From the second iteration, seed parameter is the smoothed track
// parameter at the first surface
else {
const auto& new_seed_params =
fitter_state.m_fit_actor_state.m_track_states[0].smoothed();
auto seed_params_cpy =
(i == 0) ? seed_params
: fitter_state.m_fit_actor_state.m_track_states[0]
.smoothed();

filter(new_seed_params, fitter_state);
}
inflate_covariance(seed_params_cpy,
m_cfg.covariance_inflation_factor);

filter(seed_params_cpy, fitter_state);
}
}

Expand Down Expand Up @@ -178,6 +192,9 @@ class kalman_fitter {
.template set_constraint<detray::step::constraint::e_accuracy>(
m_cfg.propagation.stepping.step_constraint);

// Reset fitter statistics
fitter_state.m_fit_res.reset_statistics();

// Run forward filtering
propagator.propagate(propagation, fitter_state());

Expand All @@ -194,14 +211,10 @@ class kalman_fitter {
/// track and vertex fitting", R.Frühwirth, NIM A.
///
/// @param fitter_state the state of kalman fitter
TRACCC_HOST_DEVICE
void smooth(state& fitter_state) {
TRACCC_HOST_DEVICE void smooth(state& fitter_state) {

auto& track_states = fitter_state.m_fit_actor_state.m_track_states;

// The smoothing algorithm requires the following:
// (1) the filtered track parameter of the current surface
// (2) the smoothed track parameter of the next surface
//
// Since the smoothed track parameter of the last surface can be
// considered to be the filtered one, we can reversly iterate the
// algorithm to obtain the smoothed parameter of other surfaces
Expand All @@ -210,14 +223,45 @@ class kalman_fitter {
last.smoothed().set_covariance(last.filtered().covariance());
last.smoothed_chi2() = last.filtered_chi2();

for (typename vector_type<track_state<algebra_type>>::reverse_iterator
it = track_states.rbegin() + 1;
it != track_states.rend(); ++it) {
if (m_cfg.use_backward_filter) {
// Backward propagator for the two-filters method
backward_propagator_type propagator(m_cfg.propagation);

// Set path limit
fitter_state.m_aborter_state.set_path_limit(
m_cfg.propagation.stepping.path_limit);

typename backward_propagator_type::state propagation(
last.smoothed(), m_field, m_detector);

inflate_covariance(propagation._stepping.bound_params(),
m_cfg.covariance_inflation_factor);

propagation._navigation.set_volume(
last.smoothed().surface_link().volume());

// Run kalman smoother
const detray::tracking_surface sf{m_detector, it->surface_link()};
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
*it, *(it - 1));
propagation._navigation.set_direction(
detray::navigation::direction::e_backward);
fitter_state.m_fit_actor_state.backward_mode = true;

propagator.propagate(propagation,
fitter_state.backward_actor_state());

// Reset the backward mode to false
fitter_state.m_fit_actor_state.backward_mode = false;

} else {
// Run the Rauch–Tung–Striebel (RTS) smoother
for (typename vector_type<
track_state<algebra_type>>::reverse_iterator it =
track_states.rbegin() + 1;
it != track_states.rend(); ++it) {

const detray::tracking_surface sf{m_detector,
it->surface_link()};
sf.template visit_mask<gain_matrix_smoother<algebra_type>>(
*it, *(it - 1));
}
}
}

Expand All @@ -233,8 +277,8 @@ class kalman_fitter {

const detray::tracking_surface sf{m_detector,
trk_state.surface_link()};
sf.template visit_mask<statistics_updater<algebra_type>>(fit_res,
trk_state);
sf.template visit_mask<statistics_updater<algebra_type>>(
fit_res, trk_state, m_cfg.use_backward_filter);
}

// Subtract the NDoF with the degree of freedom of the bound track (=5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ struct statistics_updater {
TRACCC_HOST_DEVICE inline void operator()(
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
fitting_result<algebra_t>& fit_res,
const track_state<algebra_t>& trk_state) {
const track_state<algebra_t>& trk_state,
const bool use_backward_filter) {

if (!trk_state.is_hole) {

Expand All @@ -41,7 +42,11 @@ struct statistics_updater {
fit_res.ndf += static_cast<scalar_type>(D);

// total_chi2 = total_chi2 + chi2
fit_res.chi2 += trk_state.smoothed_chi2();
if (use_backward_filter) {
fit_res.chi2 += trk_state.backward_chi2();
} else {
fit_res.chi2 += trk_state.filtered_chi2();
}
}
}
};
Expand Down
Loading
Loading