15
15
#include " traccc/utils/particle.hpp"
16
16
17
17
// detray include(s).
18
- #include " detray/propagator/base_actor.hpp"
18
+ #include < detray/navigation/navigator.hpp>
19
+ #include < detray/propagator/base_actor.hpp>
20
+
21
+ // vecmem include(s)
22
+ #include < vecmem/containers/device_vector.hpp>
19
23
20
24
namespace traccc {
21
25
22
26
// / Detray actor for Kalman filtering
23
- template <typename algebra_t , template < typename ...> class vector_t >
27
+ template <typename algebra_t >
24
28
struct kalman_actor : detray::actor {
25
29
26
30
// Type declarations
27
31
using track_state_type = track_state<algebra_t >;
28
32
29
33
// Actor state
30
34
struct state {
35
+ using iterator_t =
36
+ typename vecmem::device_vector<track_state_type>::iterator;
31
37
32
38
// / Constructor with the vector of track states
33
39
TRACCC_HOST_DEVICE
34
- state (vector_t <track_state_type>&& track_states)
35
- : m_track_states(std::move(track_states)) {
36
- m_it = m_track_states.begin ();
37
- m_it_rev = m_track_states.rbegin ();
40
+ state (vecmem::data::vector_view<track_state_type> track_states,
41
+ bool backward_mode = false )
42
+ : m_begin{track_states.ptr ()},
43
+ m_end{track_states.ptr () + track_states.size ()} {
44
+ reset (backward_mode);
38
45
}
39
46
40
47
// / Constructor with the vector of track states
41
48
TRACCC_HOST_DEVICE
42
- state (const vector_t <track_state_type> & track_states)
43
- : m_track_states(track_states) {
44
- m_it = m_track_states .begin ();
45
- m_it_rev = m_track_states. rbegin ( );
49
+ state (vecmem::device_vector<track_state< algebra_t >> & track_states,
50
+ bool backward_mode = false )
51
+ : m_begin{track_states .begin ()}, m_end{track_states. end ()} {
52
+ reset (backward_mode );
46
53
}
47
54
48
- // / @return the reference of track state pointed by the iterator
49
- TRACCC_HOST_DEVICE
50
- track_state_type& operator ()() {
51
- if (!backward_mode) {
52
- return *m_it;
53
- } else {
54
- return *m_it_rev;
55
- }
56
- }
55
+ // / Get range iterators
56
+ // / @{
57
+ TRACCC_HOST_DEVICE iterator_t begin () const { return m_begin; }
58
+ TRACCC_HOST_DEVICE iterator_t rbegin () const { return m_end; }
59
+ TRACCC_HOST_DEVICE iterator_t end () const { return m_end; }
60
+ TRACCC_HOST_DEVICE iterator_t rend () const { return m_begin; }
61
+ TRACCC_HOST_DEVICE iterator_t current () const { return m_it; }
62
+ // / @}
63
+
64
+ // / @return a reference of the track state pointed to by the iterator
65
+ TRACCC_HOST_DEVICE track_state_type& operator ()() { return *m_it; }
57
66
58
67
// / Reset the iterator
59
68
TRACCC_HOST_DEVICE
60
- void reset () {
61
- m_it = m_track_states.begin ();
62
- m_it_rev = m_track_states.rbegin ();
69
+ void reset (bool backward_mode = false ) {
70
+ m_it = backward_mode ? m_begin : m_end;
63
71
}
64
72
65
- // / Advance the iterator
66
- TRACCC_HOST_DEVICE
67
- void next () {
68
- if (!backward_mode) {
69
- m_it++;
70
- } else {
71
- m_it_rev++;
72
- }
73
- }
73
+ // / Move the iterator forward
74
+ TRACCC_HOST_DEVICE void next () { ++m_it; }
75
+
76
+ // / Move the iterator backward
77
+ TRACCC_HOST_DEVICE void previous () { --m_it; }
74
78
75
79
// / @return true if the iterator reaches the end of vector
76
80
TRACCC_HOST_DEVICE
77
- bool is_complete () {
78
- if (!backward_mode && m_it == m_track_states.end ()) {
79
- return true ;
80
- } else if (backward_mode && m_it_rev == m_track_states.rend ()) {
81
- return true ;
82
- }
83
- return false ;
81
+ bool is_complete (bool backward_mode = false ) {
82
+ return (!backward_mode && m_it == m_end) ||
83
+ (backward_mode && m_it == m_begin);
84
84
}
85
85
86
- // vector of track states
87
- vector_t <track_state_type> m_track_states;
88
-
89
- // iterator for forward filtering
90
- typename vector_t <track_state_type>::iterator m_it;
91
-
92
- // iterator for backward filtering
93
- typename vector_t <track_state_type>::reverse_iterator m_it_rev;
86
+ // / First track state of the track
87
+ iterator_t m_begin;
88
+ // / Last track state of the track
89
+ iterator_t m_end;
90
+ // / Current track state
91
+ iterator_t m_it;
94
92
95
93
// The number of holes (The number of sensitive surfaces which do not
96
94
// have a measurement for the track pattern)
97
95
unsigned int n_holes{0u };
98
-
99
- // Run back filtering for smoothing, if true
100
- bool backward_mode = false ;
101
96
};
102
97
103
98
// / Actor operation to perform the Kalman filtering
@@ -110,9 +105,11 @@ struct kalman_actor : detray::actor {
110
105
111
106
auto & stepping = propagation._stepping ;
112
107
auto & navigation = propagation._navigation ;
108
+ const bool backward_mode{navigation.direction () ==
109
+ detray::navigation::direction::e_backward};
113
110
114
111
// If the iterator reaches the end, terminate the propagation
115
- if (actor_state.is_complete ()) {
112
+ if (actor_state.is_complete (backward_mode )) {
116
113
propagation._heartbeat &= navigation.abort ();
117
114
return ;
118
115
}
@@ -125,23 +122,22 @@ struct kalman_actor : detray::actor {
125
122
// Increase the hole counts if the propagator fails to find the next
126
123
// measurement
127
124
if (navigation.barcode () != trk_state.surface_link ()) {
128
- if (!actor_state. backward_mode ) {
125
+ if (!backward_mode) {
129
126
actor_state.n_holes ++;
130
127
}
131
128
return ;
132
129
}
133
130
134
- // This track state is not a hole
135
- if (!actor_state.backward_mode ) {
136
- trk_state.is_hole = false ;
137
- }
138
-
139
131
// Run Kalman Gain Updater
140
132
const auto sf = navigation.get_surface ();
141
133
142
134
bool res = false ;
143
135
144
- if (!actor_state.backward_mode ) {
136
+ if (!backward_mode) {
137
+
138
+ // This track state is not a hole
139
+ trk_state.is_hole = false ;
140
+
145
141
// Forward filter
146
142
res = sf.template visit_mask <gain_matrix_updater<algebra_t >>(
147
143
trk_state, propagation._stepping .bound_params ());
@@ -151,10 +147,16 @@ struct kalman_actor : detray::actor {
151
147
152
148
// Set full jacobian
153
149
trk_state.jacobian () = stepping.full_jacobian ();
150
+
151
+ // Update iterator
152
+ actor_state.next ();
154
153
} else {
155
154
// Backward filter for smoothing
156
155
res = sf.template visit_mask <two_filters_smoother<algebra_t >>(
157
156
trk_state, propagation._stepping .bound_params ());
157
+
158
+ // Update iterator
159
+ actor_state.previous ();
158
160
}
159
161
160
162
// Abort if the Kalman update fails
@@ -170,9 +172,6 @@ struct kalman_actor : detray::actor {
170
172
stepping.particle_hypothesis (),
171
173
propagation._stepping .bound_params ()));
172
174
173
- // Update iterator
174
- actor_state.next ();
175
-
176
175
// Flag renavigation of the current candidate
177
176
navigation.set_high_trust ();
178
177
}
0 commit comments