Skip to content

Commit d51a9ba

Browse files
committed
Record the fitting outcome
1 parent ce80a7e commit d51a9ba

File tree

8 files changed

+95
-14
lines changed

8 files changed

+95
-14
lines changed

core/include/traccc/edm/track_state.hpp

+27
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,22 @@
1919

2020
namespace traccc {
2121

22+
enum class fitter_outcome : uint32_t {
23+
UNKNOWN,
24+
SUCCESS,
25+
FAILURE_NON_POSITIVE_NDF,
26+
FAILURE_NOT_ALL_SMOOTHED,
27+
MAX_OUTCOME
28+
};
29+
2230
/// Fitting result per track
2331
template <typename algebra_t>
2432
struct fitting_result {
2533
using scalar_type = detray::dscalar<algebra_t>;
2634

35+
/// Fitting outcome
36+
fitter_outcome fit_outcome = fitter_outcome::UNKNOWN;
37+
2738
/// Fitted track parameter
2839
detray::bound_track_parameters<algebra_t> fit_params;
2940

@@ -205,6 +216,7 @@ struct track_state {
205216

206217
public:
207218
bool is_hole{true};
219+
bool is_smoothed{false};
208220

209221
private:
210222
detray::geometry::barcode m_surface_link;
@@ -227,4 +239,19 @@ using track_state_container_types =
227239
container_types<fitting_result<default_algebra>,
228240
track_state<default_algebra>>;
229241

242+
inline std::size_t count_fitted_tracks(
243+
const track_state_container_types::host& track_states) {
244+
245+
const std::size_t n_tracks = track_states.size();
246+
std::size_t n_fitted_tracks = 0u;
247+
248+
for (std::size_t i = 0; i < n_tracks; i++) {
249+
if (track_states.at(i).header.fit_outcome == fitter_outcome::SUCCESS) {
250+
n_fitted_tracks++;
251+
}
252+
}
253+
254+
return n_fitted_tracks;
255+
}
256+
230257
} // namespace traccc

core/include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct gain_matrix_smoother {
148148
matrix::transpose(residual) * matrix::inverse(R) * residual;
149149

150150
cur_state.smoothed_chi2() = getter::element(chi2, 0, 0);
151+
cur_state.is_smoothed = true;
151152

152153
return kalman_fitter_status::SUCCESS;
153154
}

core/include/traccc/fitting/kalman_filter/kalman_fitter.hpp

+39-2
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,11 @@ class kalman_fitter {
148148
// Reset the iterator of kalman actor
149149
fitter_state.m_fit_actor_state.reset();
150150

151+
// TODO: For multiple iterations, seed parameter should be set to
152+
// the first track state which has either filtered or smoothed
153+
// state. If the first track state is a hole, we need to back
154+
// extrapolate from the filtered or smoothed state of next valid
155+
// track state.
151156
auto seed_params_cpy =
152157
(i == 0) ? seed_params
153158
: fitter_state.m_fit_actor_state.m_track_states[0]
@@ -161,6 +166,8 @@ class kalman_fitter {
161166
res != kalman_fitter_status::SUCCESS) {
162167
return res;
163168
}
169+
170+
check_fitting_result(fitter_state);
164171
}
165172

166173
return kalman_fitter_status::SUCCESS;
@@ -303,8 +310,13 @@ class kalman_fitter {
303310
auto& fit_res = fitter_state.m_fit_res;
304311
auto& track_states = fitter_state.m_fit_actor_state.m_track_states;
305312

306-
// Fit parameter = smoothed track parameter at the first surface
307-
fit_res.fit_params = track_states[0].smoothed();
313+
// Fit parameter = smoothed track parameter of the first smoothed track
314+
// state
315+
for (const auto& st : track_states) {
316+
if (st.is_smoothed) {
317+
fit_res.fit_params = st.smoothed();
318+
}
319+
}
308320

309321
for (const auto& trk_state : track_states) {
310322

@@ -321,6 +333,31 @@ class kalman_fitter {
321333
fit_res.n_holes = fitter_state.m_fit_actor_state.n_holes;
322334
}
323335

336+
TRACCC_HOST_DEVICE
337+
void check_fitting_result(state& fitter_state) {
338+
auto& fit_res = fitter_state.m_fit_res;
339+
const auto& track_states =
340+
fitter_state.m_fit_actor_state.m_track_states;
341+
342+
// NDF should always be positive for fitting
343+
if (fit_res.ndf > 0) {
344+
for (const auto& trk_state : track_states) {
345+
// Fitting fails if any of non-hole track states is not smoothed
346+
if (!trk_state.is_hole && !trk_state.is_smoothed) {
347+
fit_res.fit_outcome =
348+
fitter_outcome::FAILURE_NOT_ALL_SMOOTHED;
349+
return;
350+
}
351+
}
352+
353+
// Fitting succeeds if any of non-hole track states is not smoothed
354+
fit_res.fit_outcome = fitter_outcome::SUCCESS;
355+
}
356+
357+
fit_res.fit_outcome = fitter_outcome::FAILURE_NON_POSITIVE_NDF;
358+
return;
359+
}
360+
324361
private:
325362
// Detector object
326363
const detector_type& m_detector;

core/include/traccc/fitting/kalman_filter/statistics_updater.hpp

+7-5
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ struct statistics_updater {
3838
// Measurement dimension
3939
const unsigned int D = trk_state.get_measurement().meas_dim;
4040

41-
// NDoF = NDoF + number of coordinates per measurement
42-
fit_res.ndf += static_cast<scalar_type>(D);
43-
44-
// total_chi2 = total_chi2 + chi2
4541
if (use_backward_filter) {
46-
fit_res.chi2 += trk_state.backward_chi2();
42+
if (trk_state.is_smoothed) {
43+
// NDoF = NDoF + number of coordinates per measurement
44+
fit_res.ndf += static_cast<scalar_type>(D);
45+
fit_res.chi2 += trk_state.backward_chi2();
46+
}
4747
} else {
48+
// NDoF = NDoF + number of coordinates per measurement
49+
fit_res.ndf += static_cast<scalar_type>(D);
4850
fit_res.chi2 += trk_state.filtered_chi2();
4951
}
5052
}

core/include/traccc/fitting/kalman_filter/two_filters_smoother.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ struct two_filters_smoother {
169169
// Wrap the phi in the range of [-pi, pi]
170170
wrap_phi(bound_params);
171171

172+
trk_state.is_smoothed = true;
172173
return kalman_fitter_status::SUCCESS;
173174
}
174175
};

examples/run/cpu/truth_finding_example.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,9 @@ int seq_run(const traccc::opts::track_finding& finding_opts,
150150
auto track_states =
151151
host_fitting(detector, field, traccc::get_data(track_candidates));
152152

153-
std::cout << "Number of fitted tracks: " << track_states.size()
154-
<< std::endl;
153+
std::cout << "Number of fitted tracks: ( "
154+
<< count_fitted_tracks(track_states) << " / "
155+
<< track_states.size() << " ) " << std::endl;
155156

156157
const std::size_t n_fitted_tracks = track_states.size();
157158

examples/run/cpu/truth_fitting_example.cpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,9 @@ int main(int argc, char* argv[]) {
130130
auto track_states = host_fitting(
131131
host_det, field, traccc::get_data(truth_track_candidates));
132132

133-
std::cout << "Number of fitted tracks: " << track_states.size()
134-
<< std::endl;
133+
std::cout << "Number of fitted tracks: ( "
134+
<< count_fitted_tracks(track_states) << " / "
135+
<< track_states.size() << " ) " << std::endl;
135136

136137
const decltype(track_states)::size_type n_fitted_tracks =
137138
track_states.size();

performance/include/traccc/resolution/fitting_performance_writer.hpp

+14-3
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,17 @@ class fitting_performance_writer {
5858
const fitting_result<traccc::default_algebra>& fit_res,
5959
const detector_t& det, event_data& evt_data) {
6060

61+
if (fit_res.fit_outcome != fitter_outcome::SUCCESS) {
62+
return;
63+
}
64+
65+
// Get the first smoothed track state
66+
const auto& trk_state = *std::find_if(
67+
track_states_per_track.begin(), track_states_per_track.end(),
68+
[](const auto& st) { return st.is_smoothed; });
69+
assert(!trk_state.is_hole);
70+
assert(trk_state.is_smoothed);
71+
6172
std::map<measurement, std::map<particle, std::size_t>> meas_to_ptc_map;
6273
std::map<measurement, std::pair<point3, point3>> meas_to_param_map;
6374

@@ -69,8 +80,6 @@ class fitting_performance_writer {
6980
meas_to_param_map = evt_data.m_meas_to_param_map;
7081
}
7182

72-
// Get the track state at the first surface
73-
const auto& trk_state = track_states_per_track[0];
7483
const measurement meas = trk_state.get_measurement();
7584

7685
// Find the contributing particle
@@ -103,7 +112,9 @@ class fitting_performance_writer {
103112
if (fit_res.ndf > 0 && !trk_state.is_hole) {
104113
write_res(truth_param, trk_state.smoothed(), ptc);
105114
}
106-
write_stat(fit_res, track_states_per_track);
115+
if (fit_res.fit_outcome == fitter_outcome::SUCCESS) {
116+
write_stat(fit_res, track_states_per_track);
117+
}
107118
}
108119

109120
/// Writing caches into the file

0 commit comments

Comments
 (0)