7
7
8
8
// Project include(s).
9
9
#include "../utils/get_queue.hpp"
10
+ #include "traccc/fitting/device/fill_sort_keys.hpp"
10
11
#include "traccc/fitting/device/fit.hpp"
11
12
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
12
13
#include "traccc/sycl/fitting/fitting_algorithm.hpp"
18
19
#include "detray/navigation/navigator.hpp"
19
20
#include "detray/propagator/rk_stepper.hpp"
20
21
22
+ // DPL include(s).
23
+ #include <oneapi/dpl/algorithm>
24
+ #include <oneapi/dpl/execution>
25
+
21
26
// System include(s).
22
27
#include <vector>
23
28
@@ -27,6 +32,9 @@ namespace kernels {
27
32
/// Class identifying the kernel running @c
28
33
/// traccc::device::fit
29
34
class fit;
35
+ /// Class identifying the kernel running @c
36
+ /// traccc::device::fill_sort_keys
37
+ class fill_sort_keys;
30
38
} // namespace kernels
31
39
32
40
template <typename fitter_t>
@@ -74,19 +82,47 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
74
82
// (=32)
75
83
unsigned int localSize = 64;
76
84
85
+ vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks,
86
+ m_mr.main);
87
+ vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks,
88
+ m_mr.main);
89
+ vecmem::data::vector_view<device::sort_key> keys_view(keys_buffer);
90
+ vecmem::data::vector_view<unsigned int> param_ids_view(param_ids_buffer);
91
+
92
+ // Sort the key to get the sorted parameter ids
93
+ vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
94
+ vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
95
+
77
96
// 1 dim ND Range for the kernel
78
97
auto trackParamsNdRange =
79
98
traccc::sycl::calculate1DimNdRange(n_tracks, localSize);
80
99
100
+ details::get_queue(m_queue)
101
+ .submit([&](::sycl::handler& h) {
102
+ h.parallel_for<kernels::fill_sort_keys>(
103
+ trackParamsNdRange, [track_candidates_view, keys_view,
104
+ param_ids_view](::sycl::nd_item<1> item) {
105
+ device::fill_sort_keys(item.get_global_linear_id(),
106
+ track_candidates_view, keys_view,
107
+ param_ids_view);
108
+ });
109
+ })
110
+ .wait_and_throw();
111
+
112
+ oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default,
113
+ keys_device.begin(), keys_device.end(),
114
+ param_ids_device.begin());
115
+
81
116
details::get_queue(m_queue)
82
117
.submit([&](::sycl::handler& h) {
83
118
h.parallel_for<kernels::fit>(
84
119
trackParamsNdRange,
85
120
[det_view, field_view, config = m_cfg, track_candidates_view,
86
- track_states_view](::sycl::nd_item<1> item) {
87
- device::fit<fitter_t>(
88
- item.get_global_linear_id(), det_view, field_view,
89
- config, track_candidates_view, track_states_view);
121
+ param_ids_view, track_states_view](::sycl::nd_item<1> item) {
122
+ device::fit<fitter_t>(item.get_global_linear_id(), det_view,
123
+ field_view, config,
124
+ track_candidates_view, param_ids_view,
125
+ track_states_view);
90
126
});
91
127
})
92
128
.wait_and_throw();
0 commit comments