Skip to content

Commit 0fd99c0

Browse files
committed
Inflate covariance of seed params in the CKF
1 parent 5a2372c commit 0fd99c0

File tree

7 files changed

+55
-17
lines changed

7 files changed

+55
-17
lines changed

core/include/traccc/edm/track_parameters.hpp

+17-12
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,25 @@ inline void wrap_phi(bound_track_parameters& param) {
4848
param.set_phi(phi);
4949
}
5050

51-
/// Covariance inflation used for track fitting
52-
TRACCC_HOST_DEVICE
53-
inline void inflate_covariance(bound_track_parameters& param,
54-
const traccc::scalar inf_fac) {
55-
auto& cov = param.covariance();
56-
for (unsigned int i = 0; i < e_bound_size; i++) {
57-
for (unsigned int j = 0; j < e_bound_size; j++) {
58-
if (i == j) {
59-
getter::element(cov, i, i) *= inf_fac;
60-
} else {
61-
getter::element(cov, i, j) = 0.f;
51+
/// Struct for covariance inflation
52+
struct covariance_inflator {
53+
54+
traccc::scalar inf_fac{1.f};
55+
56+
/// Covariance inflation used for track fitting
57+
TRACCC_HOST_DEVICE
58+
inline void operator()(bound_track_parameters& param) {
59+
auto& cov = param.covariance();
60+
for (unsigned int i = 0; i < e_bound_size; i++) {
61+
for (unsigned int j = 0; j < e_bound_size; j++) {
62+
if (i == j) {
63+
getter::element(cov, i, i) *= inf_fac;
64+
} else {
65+
getter::element(cov, i, j) = 0.f;
66+
}
6267
}
6368
}
6469
}
65-
}
70+
};
6671

6772
} // namespace traccc

core/include/traccc/finding/details/find_tracks.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,10 @@ track_candidate_container_types::host find_tracks(
139139
// Copy seed to input parameters
140140
std::vector<bound_track_parameters> in_params(seeds.size());
141141
std::copy(seeds.begin(), seeds.end(), in_params.begin());
142+
143+
// Inflate the covariance
144+
covariance_inflator cov_inf(config.covariance_inflation_factor);
145+
std::for_each(in_params.begin(), in_params.end(), cov_inf);
142146
std::vector<unsigned int> n_trks_per_seed(seeds.size());
143147

144148
std::vector<bound_track_parameters> out_params;

core/include/traccc/finding/finding_config.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,9 @@ struct finding_config {
4848
/// Particle hypothesis
4949
detray::pdg_particle<traccc::scalar> ptc_hypothesis =
5050
detray::muon<traccc::scalar>();
51+
52+
/// Covariance inflation
53+
traccc::scalar covariance_inflation_factor = 1e3f;
5154
};
5255

5356
} // namespace traccc

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ class kalman_fitter {
153153
: fitter_state.m_fit_actor_state.m_track_states[0]
154154
.smoothed();
155155

156-
inflate_covariance(seed_params_cpy,
157-
m_cfg.covariance_inflation_factor);
156+
covariance_inflator cov_inf(m_cfg.covariance_inflation_factor);
157+
cov_inf(seed_params_cpy);
158158

159159
filter(seed_params_cpy, fitter_state);
160160
}
@@ -234,8 +234,8 @@ class kalman_fitter {
234234
typename backward_propagator_type::state propagation(
235235
last.smoothed(), m_field, m_detector);
236236

237-
inflate_covariance(propagation._stepping.bound_params(),
238-
m_cfg.covariance_inflation_factor);
237+
covariance_inflator cov_inf(m_cfg.covariance_inflation_factor);
238+
cov_inf(propagation._stepping.bound_params());
239239

240240
propagation._navigation.set_volume(
241241
last.smoothed().surface_link().volume());

device/cuda/src/finding/finding_algorithm.cu

+8
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
154154
m_copy.setup(in_params_buffer)->ignore();
155155
m_copy(vecmem::get_data(seeds_buffer), vecmem::get_data(in_params_buffer))
156156
->ignore();
157+
158+
// Inflate the covariance
159+
bound_track_parameters_collection_types::device in_params_device(
160+
in_params_buffer);
161+
covariance_inflator cov_inf(m_cfg.covariance_inflation_factor);
162+
thrust::for_each(thrust_policy, in_params_device.begin(),
163+
in_params_device.end(), cov_inf);
164+
157165
vecmem::data::vector_buffer<unsigned int> param_liveness_buffer(n_seeds,
158166
m_mr.main);
159167
m_copy.setup(param_liveness_buffer)->ignore();

device/sycl/src/finding/find_tracks.hpp

+11
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,17 @@ track_candidate_container_types::buffer find_tracks(
172172
mr.main);
173173
copy.setup(in_params_buffer)->wait();
174174
copy(seeds, in_params_buffer, vecmem::copy::type::device_to_device)->wait();
175+
176+
// Inflate the covariance
177+
bound_track_parameters_collection_types::device in_params_device(
178+
in_params_buffer);
179+
oneapi::dpl::for_each(
180+
policy, in_params_device.begin(), in_params_device.end(),
181+
[config](bound_track_parameters& params) {
182+
covariance_inflator cov_inf(config.covariance_inflation_factor);
183+
cov_inf(params);
184+
});
185+
175186
vecmem::data::vector_buffer<unsigned int> param_liveness_buffer(n_seeds,
176187
mr.main);
177188
copy.setup(param_liveness_buffer)->wait();

examples/options/src/track_finding.cpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,11 @@ track_finding::track_finding() : interface("Track Finding Options") {
6969
m_desc.add_options()("particle-hypothesis",
7070
po::value(&m_pdg_number)->default_value(m_pdg_number),
7171
"PDG number for the particle hypothesis");
72+
m_desc.add_options()(
73+
"find-covariance-inflation-factor",
74+
po::value(&m_config.covariance_inflation_factor)
75+
->default_value(m_config.covariance_inflation_factor),
76+
"Covariance inflation factor for the track finding");
7277
}
7378

7479
track_finding::operator finding_config() const {
@@ -111,7 +116,9 @@ std::unique_ptr<configuration_printable> track_finding::as_printable() const {
111116
std::to_string(m_config.max_num_skipping_per_cand)));
112117
cat->add_child(std::make_unique<configuration_kv_pair>(
113118
"PDG number", std::to_string(m_pdg_number)));
114-
119+
cat->add_child(std::make_unique<configuration_kv_pair>(
120+
"Covariance inflation factor",
121+
std::to_string(m_config.covariance_inflation_factor)));
115122
return cat;
116123
}
117124
} // namespace traccc::opts

0 commit comments

Comments
 (0)