Skip to content

Commit 9e33a72

Browse files
committed
Do not copy the bound track parameters for the gain matrix updater
1 parent c75240f commit 9e33a72

File tree

4 files changed

+11
-13
lines changed

4 files changed

+11
-13
lines changed

core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp

+5-9
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct gain_matrix_updater {
3939
TRACCC_HOST_DEVICE inline bool operator()(
4040
const mask_group_t& /*mask_group*/, const index_t& /*index*/,
4141
track_state<algebra_t>& trk_state,
42-
bound_track_parameters& bound_params) const {
42+
const bound_track_parameters& bound_params) const {
4343

4444
using shape_type = typename mask_group_t::value_type::shape;
4545

@@ -57,7 +57,7 @@ struct gain_matrix_updater {
5757
template <size_type D, typename shape_t>
5858
TRACCC_HOST_DEVICE inline bool update(
5959
track_state<algebra_t>& trk_state,
60-
bound_track_parameters& bound_params) const {
60+
const bound_track_parameters& bound_params) const {
6161

6262
static_assert(((D == 1u) || (D == 2u)),
6363
"The measurement dimension should be 1 or 2");
@@ -123,25 +123,21 @@ struct gain_matrix_updater {
123123
const matrix_type<1, 1> chi2 = matrix_operator().transpose(residual) *
124124
matrix_operator().inverse(R) * residual;
125125

126-
// Set the stepper parameter
127-
bound_params.set_vector(filtered_vec);
128-
bound_params.set_covariance(filtered_cov);
129-
130126
// Return false if track is parallel to z-axis or phi is not finite
131127
const scalar theta = bound_params.theta();
132128
if (theta <= 0.f || theta >= constant<traccc::scalar>::pi ||
133129
!std::isfinite(bound_params.phi())) {
134130
return false;
135131
}
136132

137-
// Wrap the phi in the range of [-pi, pi]
138-
wrap_phi(bound_params);
139-
140133
// Set the track state parameters
141134
trk_state.filtered().set_vector(filtered_vec);
142135
trk_state.filtered().set_covariance(filtered_cov);
143136
trk_state.filtered_chi2() = matrix_operator().element(chi2, 0, 0);
144137

138+
// Wrap the phi in the range of [-pi, pi]
139+
wrap_phi(trk_state.filtered());
140+
145141
return true;
146142
}
147143
};

core/include/traccc/fitting/kalman_filter/kalman_actor.hpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ struct kalman_actor : detray::actor {
7878
TRACCC_HOST_DEVICE void operator()(state& actor_state,
7979
propagator_state_t& propagation) const {
8080

81-
const auto& stepping = propagation._stepping;
81+
auto& stepping = propagation._stepping;
8282
auto& navigation = propagation._navigation;
8383

8484
// If the iterator reaches the end, terminate the propagation
@@ -114,6 +114,9 @@ struct kalman_actor : detray::actor {
114114
return;
115115
}
116116

117+
// Update the propagation flow
118+
stepping.bound_params() = trk_state.filtered();
119+
117120
// Set full jacobian
118121
trk_state.jacobian() = stepping.full_jacobian();
119122

core/src/finding/find_tracks.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,9 @@ track_candidate_container_types::host find_tracks(
239239
track_state<algebra_type> trk_state(meas);
240240

241241
// Run the Kalman update on a copy of the track parameters
242-
bound_track_parameters bound_param(in_param);
243242
const bool res =
244243
sf.template visit_mask<gain_matrix_updater<algebra_type>>(
245-
trk_state, bound_param);
244+
trk_state, in_param);
246245

247246
// The chi2 from Kalman update should be less than chi2_max
248247
if (res && trk_state.filtered_chi2() < config.chi2_max) {

device/common/include/traccc/finding/device/impl/find_tracks.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ TRACCC_DEVICE inline void find_tracks(
182182
owner_local_thread_id +
183183
thread_id.getBlockDimX() * thread_id.getBlockIdX();
184184
assert(in_params_liveness.at(owner_global_thread_id) != 0u);
185-
bound_track_parameters in_par =
185+
const bound_track_parameters& in_par =
186186
in_params.at(owner_global_thread_id);
187187
const unsigned int meas_idx =
188188
shared_payload.shared_candidates[thread_id.getLocalThreadIdX()]

0 commit comments

Comments
 (0)