Skip to content

Commit 715efb8

Browse files
committed
Try not to use the vector container type of the detector in the KF
1 parent 077cc64 commit 715efb8

File tree

10 files changed

+107
-102
lines changed

10 files changed

+107
-102
lines changed

benchmarks/common/benchmarks/toy_detector_benchmark.hpp

-3
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,6 @@
2222
#include "detray/detectors/bfield.hpp"
2323
#include "detray/io/frontend/detector_reader.hpp"
2424
#include "detray/io/frontend/detector_writer.hpp"
25-
#include "detray/navigation/navigator.hpp"
26-
#include "detray/propagator/propagator.hpp"
27-
#include "detray/propagator/rk_stepper.hpp"
2825
#include "detray/test/utils/detectors/build_toy_detector.hpp"
2926
#include "detray/test/utils/simulation/event_generator/track_generators.hpp"
3027
#include "detray/tracks/ray.hpp"

benchmarks/cpu/toy_detector_cpu.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@
2727
#include "detray/core/detector.hpp"
2828
#include "detray/detectors/bfield.hpp"
2929
#include "detray/io/frontend/detector_reader.hpp"
30-
#include "detray/navigation/navigator.hpp"
31-
#include "detray/propagator/propagator.hpp"
32-
#include "detray/propagator/rk_stepper.hpp"
3330

3431
// VecMem include(s).
3532
#include <vecmem/memory/host_memory_resource.hpp>

core/include/traccc/edm/track_state.hpp

-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "traccc/edm/track_candidate.hpp"
1616

1717
// detray include(s).
18-
#include "detray/navigation/navigator.hpp"
1918
#include "detray/tracks/bound_track_parameters.hpp"
2019

2120
namespace traccc {

core/include/traccc/fitting/details/fit_tracks.hpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,14 @@ track_state_container_types::host fit_tracks(
5858
}
5959

6060
// Make a fitter state
61-
typename fitter_t::state fitter_state(std::move(input_states));
61+
typename fitter_t::state fitter_state(vecmem::get_data(input_states));
6262

6363
// Run the fitter.
6464
fitter.fit(track_candidates.get_headers()[i], fitter_state);
6565

6666
// Save the results into the output container.
67-
result.push_back(
68-
std::move(fitter_state.m_fit_res),
69-
std::move(fitter_state.m_fit_actor_state.m_track_states));
67+
result.push_back(std::move(fitter_state.m_fit_res),
68+
std::move(input_states));
7069
}
7170

7271
// Return the fitted track states.

core/include/traccc/fitting/kalman_filter/gain_matrix_smoother.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#include "traccc/edm/track_parameters.hpp"
1313
#include "traccc/edm/track_state.hpp"
1414

15+
// detray include(s)
16+
#include <detray/geometry/shapes/line.hpp>
17+
1518
namespace traccc {
1619

1720
/// Type unrolling functor to smooth the track parameters after the Kalman

core/include/traccc/fitting/kalman_filter/gain_matrix_updater.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
#include "traccc/definitions/track_parametrization.hpp"
1313
#include "traccc/edm/track_state.hpp"
1414

15+
// detray include(s)
16+
#include <detray/geometry/shapes/line.hpp>
17+
1518
namespace traccc {
1619

1720
/// Type unrolling functor for Kalman updating

core/include/traccc/fitting/kalman_filter/kalman_actor.hpp

+59-60
Original file line numberDiff line numberDiff line change
@@ -15,89 +15,84 @@
1515
#include "traccc/utils/particle.hpp"
1616

1717
// detray include(s).
18-
#include "detray/propagator/base_actor.hpp"
18+
#include <detray/navigation/navigator.hpp>
19+
#include <detray/propagator/base_actor.hpp>
20+
21+
// vecmem include(s)
22+
#include <vecmem/containers/device_vector.hpp>
1923

2024
namespace traccc {
2125

2226
/// Detray actor for Kalman filtering
23-
template <typename algebra_t, template <typename...> class vector_t>
27+
template <typename algebra_t>
2428
struct kalman_actor : detray::actor {
2529

2630
// Type declarations
2731
using track_state_type = track_state<algebra_t>;
2832

2933
// Actor state
3034
struct state {
35+
using iterator_t =
36+
typename vecmem::device_vector<track_state_type>::iterator;
3137

3238
/// Constructor with the vector of track states
3339
TRACCC_HOST_DEVICE
34-
state(vector_t<track_state_type>&& track_states)
35-
: m_track_states(std::move(track_states)) {
36-
m_it = m_track_states.begin();
37-
m_it_rev = m_track_states.rbegin();
40+
state(vecmem::data::vector_view<track_state_type> track_states,
41+
bool backward_mode = false)
42+
: m_begin{track_states.ptr()},
43+
m_end{track_states.ptr() + track_states.size()} {
44+
reset(backward_mode);
3845
}
3946

4047
/// Constructor with the vector of track states
4148
TRACCC_HOST_DEVICE
42-
state(const vector_t<track_state_type>& track_states)
43-
: m_track_states(track_states) {
44-
m_it = m_track_states.begin();
45-
m_it_rev = m_track_states.rbegin();
49+
state(vecmem::device_vector<track_state<algebra_t>>& track_states,
50+
bool backward_mode = false)
51+
: m_begin{track_states.begin()}, m_end{track_states.end()} {
52+
reset(backward_mode);
4653
}
4754

48-
/// @return the reference of track state pointed by the iterator
49-
TRACCC_HOST_DEVICE
50-
track_state_type& operator()() {
51-
if (!backward_mode) {
52-
return *m_it;
53-
} else {
54-
return *m_it_rev;
55-
}
56-
}
55+
/// Get range iterators
56+
/// @{
57+
TRACCC_HOST_DEVICE iterator_t begin() const { return m_begin; }
58+
TRACCC_HOST_DEVICE iterator_t rbegin() const { return m_end; }
59+
TRACCC_HOST_DEVICE iterator_t end() const { return m_end; }
60+
TRACCC_HOST_DEVICE iterator_t rend() const { return m_begin; }
61+
TRACCC_HOST_DEVICE iterator_t current() const { return m_it; }
62+
/// @}
63+
64+
/// @return a reference of the track state pointed to by the iterator
65+
TRACCC_HOST_DEVICE track_state_type& operator()() { return *m_it; }
5766

5867
/// Reset the iterator
5968
TRACCC_HOST_DEVICE
60-
void reset() {
61-
m_it = m_track_states.begin();
62-
m_it_rev = m_track_states.rbegin();
69+
void reset(bool backward_mode = false) {
70+
m_it = backward_mode ? m_begin : m_end;
6371
}
6472

65-
/// Advance the iterator
66-
TRACCC_HOST_DEVICE
67-
void next() {
68-
if (!backward_mode) {
69-
m_it++;
70-
} else {
71-
m_it_rev++;
72-
}
73-
}
73+
/// Move the iterator forward
74+
TRACCC_HOST_DEVICE void next() { ++m_it; }
75+
76+
/// Move the iterator backward
77+
TRACCC_HOST_DEVICE void previous() { --m_it; }
7478

7579
/// @return true if the iterator reaches the end of vector
7680
TRACCC_HOST_DEVICE
77-
bool is_complete() {
78-
if (!backward_mode && m_it == m_track_states.end()) {
79-
return true;
80-
} else if (backward_mode && m_it_rev == m_track_states.rend()) {
81-
return true;
82-
}
83-
return false;
81+
bool is_complete(bool backward_mode = false) {
82+
return (!backward_mode && m_it == m_end) ||
83+
(backward_mode && m_it == m_begin);
8484
}
8585

86-
// vector of track states
87-
vector_t<track_state_type> m_track_states;
88-
89-
// iterator for forward filtering
90-
typename vector_t<track_state_type>::iterator m_it;
91-
92-
// iterator for backward filtering
93-
typename vector_t<track_state_type>::reverse_iterator m_it_rev;
86+
/// First track state of the track
87+
iterator_t m_begin;
88+
/// Last track state of the track
89+
iterator_t m_end;
90+
/// Current track state
91+
iterator_t m_it;
9492

9593
// The number of holes (The number of sensitive surfaces which do not
9694
// have a measurement for the track pattern)
9795
unsigned int n_holes{0u};
98-
99-
// Run back filtering for smoothing, if true
100-
bool backward_mode = false;
10196
};
10297

10398
/// Actor operation to perform the Kalman filtering
@@ -110,9 +105,11 @@ struct kalman_actor : detray::actor {
110105

111106
auto& stepping = propagation._stepping;
112107
auto& navigation = propagation._navigation;
108+
const bool backward_mode{navigation.direction() ==
109+
detray::navigation::direction::e_backward};
113110

114111
// If the iterator reaches the end, terminate the propagation
115-
if (actor_state.is_complete()) {
112+
if (actor_state.is_complete(backward_mode)) {
116113
propagation._heartbeat &= navigation.abort();
117114
return;
118115
}
@@ -125,23 +122,22 @@ struct kalman_actor : detray::actor {
125122
// Increase the hole counts if the propagator fails to find the next
126123
// measurement
127124
if (navigation.barcode() != trk_state.surface_link()) {
128-
if (!actor_state.backward_mode) {
125+
if (!backward_mode) {
129126
actor_state.n_holes++;
130127
}
131128
return;
132129
}
133130

134-
// This track state is not a hole
135-
if (!actor_state.backward_mode) {
136-
trk_state.is_hole = false;
137-
}
138-
139131
// Run Kalman Gain Updater
140132
const auto sf = navigation.get_surface();
141133

142134
bool res = false;
143135

144-
if (!actor_state.backward_mode) {
136+
if (!backward_mode) {
137+
138+
// This track state is not a hole
139+
trk_state.is_hole = false;
140+
145141
// Forward filter
146142
res = sf.template visit_mask<gain_matrix_updater<algebra_t>>(
147143
trk_state, propagation._stepping.bound_params());
@@ -151,10 +147,16 @@ struct kalman_actor : detray::actor {
151147

152148
// Set full jacobian
153149
trk_state.jacobian() = stepping.full_jacobian();
150+
151+
// Update iterator
152+
actor_state.next();
154153
} else {
155154
// Backward filter for smoothing
156155
res = sf.template visit_mask<two_filters_smoother<algebra_t>>(
157156
trk_state, propagation._stepping.bound_params());
157+
158+
// Update iterator
159+
actor_state.previous();
158160
}
159161

160162
// Abort if the Kalman update fails
@@ -170,9 +172,6 @@ struct kalman_actor : detray::actor {
170172
stepping.particle_hypothesis(),
171173
propagation._stepping.bound_params()));
172174

173-
// Update iterator
174-
actor_state.next();
175-
176175
// Flag renavigation of the current candidate
177176
navigation.set_high_trust();
178177
}

0 commit comments

Comments
 (0)