Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 8680ae5

Browse files
committedSep 20, 2024·
Access track parameters in a sorted order to minimize the branching divergence
1 parent 595dad9 commit 8680ae5

13 files changed

+295
-21
lines changed
 

‎device/common/CMakeLists.txt

+5
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ traccc_add_library( traccc_device_common device_common TYPE SHARED
2626
"include/traccc/edm/device/triplet_counter.hpp"
2727
"include/traccc/edm/device/device_doublet.hpp"
2828
"include/traccc/edm/device/device_triplet.hpp"
29+
"include/traccc/edm/device/sort_key.hpp"
2930
# Clusterization function(s).
3031
"include/traccc/clusterization/device/form_spacepoints.hpp"
3132
"include/traccc/clusterization/device/impl/form_spacepoints.ipp"
@@ -65,6 +66,7 @@ traccc_add_library( traccc_device_common device_common TYPE SHARED
6566
"include/traccc/finding/device/build_tracks.hpp"
6667
"include/traccc/finding/device/count_measurements.hpp"
6768
"include/traccc/finding/device/find_tracks.hpp"
69+
"include/traccc/finding/device/get_sort_key_value.hpp"
6870
"include/traccc/finding/device/add_links_for_holes.hpp"
6971
"include/traccc/finding/device/make_barcode_sequence.hpp"
7072
"include/traccc/finding/device/propagate_to_next_surface.hpp"
@@ -73,13 +75,16 @@ traccc_add_library( traccc_device_common device_common TYPE SHARED
7375
"include/traccc/finding/device/impl/build_tracks.ipp"
7476
"include/traccc/finding/device/impl/count_measurements.ipp"
7577
"include/traccc/finding/device/impl/find_tracks.ipp"
78+
"include/traccc/finding/device/impl/get_sort_key_value.ipp"
7679
"include/traccc/finding/device/impl/add_links_for_holes.ipp"
7780
"include/traccc/finding/device/impl/make_barcode_sequence.ipp"
7881
"include/traccc/finding/device/impl/propagate_to_next_surface.ipp"
7982
"include/traccc/finding/device/impl/prune_tracks.ipp"
8083
# Track fitting funtions(s).
8184
"include/traccc/fitting/device/fit.hpp"
8285
"include/traccc/fitting/device/impl/fit.ipp"
86+
"include/traccc/fitting/device/get_sort_key_value.hpp"
87+
"include/traccc/fitting/device/impl/get_sort_key_value.ipp"
8388
)
8489
target_link_libraries( traccc_device_common
8590
PUBLIC traccc::Thrust traccc::core vecmem::core )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/definitions/primitives.hpp"
12+
#include "traccc/edm/track_candidate.hpp"
13+
#include "traccc/edm/track_parameters.hpp"
14+
15+
namespace traccc::device {
16+
17+
struct sort_key {
18+
traccc::scalar key;
19+
};
20+
21+
TRACCC_HOST_DEVICE
22+
inline sort_key get_sort_key(const bound_track_parameters& params) {
23+
// key = |theta - pi/2|
24+
return sort_key{math::abs(params.theta() - constant<traccc::scalar>::pi_2)};
25+
}
26+
27+
TRACCC_HOST_DEVICE
28+
inline sort_key get_sort_key(
29+
const track_candidate_collection_types::const_device& candidates) {
30+
// Number of candidates
31+
return sort_key{static_cast<traccc::scalar>(candidates.size())};
32+
}
33+
34+
/// Comparator based on key
35+
struct sort_key_comp {
36+
TRACCC_HOST_DEVICE
37+
bool operator()(const sort_key& lhs, const sort_key& rhs) {
38+
return lhs.key < rhs.key;
39+
}
40+
};
41+
42+
} // namespace traccc::device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/edm/device/sort_key.hpp"
12+
#include "traccc/edm/track_candidate.hpp"
13+
14+
// System include(s).
15+
#include <cstddef>
16+
17+
namespace traccc::device {
18+
19+
/// Function used for fitting a track for a given track candidates
20+
///
21+
/// @param[in] globalIndex The index of the current thread
22+
/// @param[out] keys_view The key values
23+
/// @param[out] ids_view The param ids
24+
///
25+
TRACCC_HOST_DEVICE inline void get_sort_key_value(
26+
std::size_t globalIndex,
27+
bound_track_parameters_collection_types::const_view params_view,
28+
vecmem::data::vector_view<device::sort_key> keys_view,
29+
vecmem::data::vector_view<unsigned int> ids_view);
30+
31+
} // namespace traccc::device
32+
33+
// Include the implementation.
34+
#include "traccc/finding/device/impl/get_sort_key_value.ipp"

‎device/common/include/traccc/finding/device/impl/find_tracks.ipp

+1-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ TRACCC_DEVICE inline void find_tracks(
145145
n_candidates[in_param_id]);
146146
num_candidates.fetch_add(1);
147147

148-
out_params[l_pos] = trk_state.filtered();
148+
out_params.at(l_pos) = trk_state.filtered();
149149
}
150150
}
151151
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::device {
11+
12+
TRACCC_HOST_DEVICE inline void get_sort_key_value(
13+
std::size_t globalIndex,
14+
bound_track_parameters_collection_types::const_view params_view,
15+
vecmem::data::vector_view<device::sort_key> keys_view,
16+
vecmem::data::vector_view<unsigned int> ids_view) {
17+
18+
bound_track_parameters_collection_types::const_device params(params_view);
19+
20+
// Keys
21+
vecmem::device_vector<device::sort_key> keys_device(keys_view);
22+
23+
// Param id
24+
vecmem::device_vector<unsigned int> ids_device(ids_view);
25+
26+
if (globalIndex >= keys_device.size()) {
27+
return;
28+
}
29+
30+
keys_device.at(globalIndex) = device::get_sort_key(params.at(globalIndex));
31+
ids_device.at(globalIndex) = globalIndex;
32+
}
33+
34+
} // namespace traccc::device

‎device/common/include/traccc/finding/device/impl/propagate_to_next_surface.ipp

+13-7
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
1515
typename propagator_t::detector_type::view_type det_data,
1616
bfield_t field_data,
1717
bound_track_parameters_collection_types::const_view in_params_view,
18+
vecmem::data::vector_view<const unsigned int> param_ids_view,
1819
vecmem::data::vector_view<const candidate_link> links_view,
1920
const unsigned int step, const unsigned int& n_in_params,
2021
bound_track_parameters_collection_types::view out_params_view,
@@ -28,6 +29,11 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
2829
return;
2930
}
3031

32+
// Theta id
33+
vecmem::device_vector<const unsigned int> param_ids(param_ids_view);
34+
35+
const unsigned int param_id = param_ids.at(globalIndex);
36+
3137
// Number of tracks per seed
3238
vecmem::device_vector<unsigned int> n_tracks_per_seed(
3339
n_tracks_per_seed_view);
@@ -36,7 +42,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
3642
vecmem::device_vector<const candidate_link> links(links_view);
3743

3844
// Seed id
39-
unsigned int orig_param_id = links.at(globalIndex).seed_idx;
45+
unsigned int orig_param_id = links.at(param_id).seed_idx;
4046

4147
// Count the number of tracks per seed
4248
vecmem::device_atomic_ref<unsigned int> num_tracks_per_seed(
@@ -52,8 +58,8 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
5258
vecmem::device_vector<typename candidate_link::link_index_type> tips(
5359
tips_view);
5460

55-
if (links[globalIndex].n_skipped > cfg.max_num_skipping_per_cand) {
56-
tips.push_back({step, globalIndex});
61+
if (links.at(param_id).n_skipped > cfg.max_num_skipping_per_cand) {
62+
tips.push_back({step, param_id});
5763
return;
5864
}
5965

@@ -71,7 +77,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
7177
vecmem::device_vector<unsigned int> param_to_link(param_to_link_view);
7278

7379
// Input bound track parameter
74-
const bound_track_parameters in_par = in_params.at(globalIndex);
80+
const bound_track_parameters in_par = in_params.at(param_id);
7581

7682
// Create propagator
7783
propagator_t propagator(cfg.propagation);
@@ -115,17 +121,17 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
115121

116122
out_params[out_param_id] = propagation._stepping._bound_params;
117123

118-
param_to_link[out_param_id] = static_cast<unsigned int>(globalIndex);
124+
param_to_link[out_param_id] = param_id;
119125
}
120126
// Unless the track found a surface, it is considered a tip
121127
else if (!s4.success && step >= cfg.min_track_candidates_per_track - 1) {
122-
tips.push_back({step, globalIndex});
128+
tips.push_back({step, param_id});
123129
}
124130

125131
// If no more CKF step is expected, current candidate is
126132
// kept as a tip
127133
if (s4.success && step == cfg.max_track_candidates_per_track - 1) {
128-
tips.push_back({step, globalIndex});
134+
tips.push_back({step, param_id});
129135
}
130136
}
131137

‎device/common/include/traccc/finding/device/propagate_to_next_surface.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ namespace traccc::device {
2626
/// @param[in] cfg Track finding config object
2727
/// @param[in] det_data Detector view object
2828
/// @param[in] in_params_view Input parameters
29+
/// @param[in] param_ids_view Sorted param ids
2930
/// @param[in] links_view Link container for the current step
3031
/// @param[in] step Step index
3132
/// @param[in] n_in_params The number of input parameters
@@ -41,6 +42,7 @@ TRACCC_DEVICE inline void propagate_to_next_surface(
4142
typename propagator_t::detector_type::view_type det_data,
4243
bfield_t field_data,
4344
bound_track_parameters_collection_types::const_view in_params_view,
45+
vecmem::data::vector_view<const unsigned int> param_ids_view,
4446
vecmem::data::vector_view<const candidate_link> links_view,
4547
const unsigned int step, const unsigned int& n_in_params,
4648
bound_track_parameters_collection_types::view out_params_view,

‎device/common/include/traccc/fitting/device/fit.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TRACCC_HOST_DEVICE inline void fit(
3131
const typename fitter_t::bfield_type field_data,
3232
const typename fitter_t::config_type cfg,
3333
track_candidate_container_types::const_view track_candidates_view,
34+
vecmem::data::vector_view<const unsigned int> param_ids_view,
3435
track_state_container_types::view track_states_view);
3536

3637
} // namespace traccc::device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/edm/device/sort_key.hpp"
12+
#include "traccc/edm/track_candidate.hpp"
13+
14+
// System include(s).
15+
#include <cstddef>
16+
17+
namespace traccc::device {
18+
19+
/// Function used for fitting a track for a given track candidates
20+
///
21+
/// @param[in] globalIndex The index of the current thread
22+
/// @param[in] track_candidates_view The input track candidates
23+
/// @param[out] keys_view The key values
24+
/// @param[out] ids_view The param ids
25+
///
26+
TRACCC_HOST_DEVICE inline void get_sort_key_value(
27+
std::size_t globalIndex,
28+
track_candidate_container_types::const_view track_candidates_view,
29+
vecmem::data::vector_view<device::sort_key> keys_view,
30+
vecmem::data::vector_view<unsigned int> ids_view);
31+
32+
} // namespace traccc::device
33+
34+
// Include the implementation.
35+
#include "traccc/fitting/device/impl/get_sort_key_value.ipp"

‎device/common/include/traccc/fitting/device/impl/fit.ipp

+9-4
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ TRACCC_HOST_DEVICE inline void fit(
1616
const typename fitter_t::bfield_type field_data,
1717
const typename fitter_t::config_type cfg,
1818
track_candidate_container_types::const_view track_candidates_view,
19+
vecmem::data::vector_view<const unsigned int> param_ids_view,
1920
track_state_container_types::view track_states_view) {
2021

2122
typename fitter_t::detector_type det(det_data);
2223

2324
track_candidate_container_types::const_device track_candidates(
2425
track_candidates_view);
2526

27+
vecmem::device_vector<const unsigned int> param_ids(param_ids_view);
28+
2629
track_state_container_types::device track_states(track_states_view);
2730

2831
fitter_t fitter(det, field_data, cfg);
@@ -31,15 +34,17 @@ TRACCC_HOST_DEVICE inline void fit(
3134
return;
3235
}
3336

37+
const unsigned int param_id = param_ids.at(globalIndex);
38+
3439
// Track candidates per track
3540
const auto& track_candidates_per_track =
36-
track_candidates[globalIndex].items;
41+
track_candidates.at(param_id).items;
3742

3843
// Seed parameter
39-
const auto& seed_param = track_candidates[globalIndex].header;
44+
const auto& seed_param = track_candidates.at(param_id).header;
4045

4146
// Track states per track
42-
auto track_states_per_track = track_states[globalIndex].items;
47+
auto track_states_per_track = track_states.at(param_id).items;
4348

4449
for (auto& cand : track_candidates_per_track) {
4550
track_states_per_track.emplace_back(cand);
@@ -51,7 +56,7 @@ TRACCC_HOST_DEVICE inline void fit(
5156
fitter.fit(seed_param, fitter_state);
5257

5358
// Get the final fitting information
54-
track_states[globalIndex].header = fitter_state.m_fit_res;
59+
track_states.at(param_id).header = fitter_state.m_fit_res;
5560
}
5661

5762
} // namespace traccc::device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::device {
11+
12+
TRACCC_HOST_DEVICE inline void get_sort_key_value(
13+
std::size_t globalIndex,
14+
track_candidate_container_types::const_view track_candidates_view,
15+
vecmem::data::vector_view<device::sort_key> keys_view,
16+
vecmem::data::vector_view<unsigned int> ids_view) {
17+
18+
track_candidate_container_types::const_device track_candidates(
19+
track_candidates_view);
20+
21+
if (globalIndex >= track_candidates.size()) {
22+
return;
23+
}
24+
25+
// Keys
26+
vecmem::device_vector<device::sort_key> keys_device(keys_view);
27+
28+
// Param id
29+
vecmem::device_vector<unsigned int> ids_device(ids_view);
30+
31+
keys_device.at(globalIndex) =
32+
device::get_sort_key(track_candidates.at(globalIndex).items);
33+
ids_device.at(globalIndex) = globalIndex;
34+
}
35+
36+
} // namespace traccc::device

‎device/cuda/src/finding/finding_algorithm.cu

+48-7
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
#include "traccc/definitions/primitives.hpp"
1414
#include "traccc/definitions/qualifiers.hpp"
1515
#include "traccc/edm/device/finding_global_counter.hpp"
16+
#include "traccc/edm/device/sort_key.hpp"
1617
#include "traccc/finding/candidate_link.hpp"
1718
#include "traccc/finding/device/add_links_for_holes.hpp"
1819
#include "traccc/finding/device/apply_interaction.hpp"
1920
#include "traccc/finding/device/build_tracks.hpp"
2021
#include "traccc/finding/device/count_measurements.hpp"
2122
#include "traccc/finding/device/find_tracks.hpp"
23+
#include "traccc/finding/device/get_sort_key_value.hpp"
2224
#include "traccc/finding/device/make_barcode_sequence.hpp"
2325
#include "traccc/finding/device/propagate_to_next_surface.hpp"
2426
#include "traccc/finding/device/prune_tracks.hpp"
@@ -137,13 +139,25 @@ __global__ void add_links_for_holes(
137139
n_total_candidates);
138140
}
139141

142+
/// CUDA kernel for running @c traccc::device::get_sort_key_value
143+
__global__ void get_sort_key_value(
144+
bound_track_parameters_collection_types::const_view params_view,
145+
vecmem::data::vector_view<device::sort_key> keys_view,
146+
vecmem::data::vector_view<unsigned int> ids_view) {
147+
148+
int gid = threadIdx.x + blockIdx.x * blockDim.x;
149+
150+
device::get_sort_key_value(gid, params_view, keys_view, ids_view);
151+
}
152+
140153
/// CUDA kernel for running @c traccc::device::propagate_to_next_surface
141154
template <typename propagator_t, typename bfield_t, typename config_t>
142155
__global__ void propagate_to_next_surface(
143156
const config_t cfg,
144157
typename propagator_t::detector_type::view_type det_data,
145158
bfield_t field_data,
146159
bound_track_parameters_collection_types::const_view in_params_view,
160+
vecmem::data::vector_view<const unsigned int> param_ids_view,
147161
vecmem::data::vector_view<const candidate_link> links_view,
148162
const unsigned int step, const unsigned int& n_candidates,
149163
bound_track_parameters_collection_types::view out_params_view,
@@ -156,9 +170,9 @@ __global__ void propagate_to_next_surface(
156170
int gid = threadIdx.x + blockIdx.x * blockDim.x;
157171

158172
device::propagate_to_next_surface<propagator_t, bfield_t, config_t>(
159-
gid, cfg, det_data, field_data, in_params_view, links_view, step,
160-
n_candidates, out_params_view, param_to_link_view, tips_view,
161-
n_tracks_per_seed_view, n_out_params);
173+
gid, cfg, det_data, field_data, in_params_view, param_ids_view,
174+
links_view, step, n_candidates, out_params_view, param_to_link_view,
175+
tips_view, n_tracks_per_seed_view, n_out_params);
162176
}
163177

164178
/// CUDA kernel for running @c traccc::device::build_tracks
@@ -457,7 +471,33 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
457471
m_stream.synchronize();
458472

459473
/*****************************************************************
460-
* Kernel6: Propagate to the next surface
474+
* Kernel6: Get key and value for parameter sorting
475+
*****************************************************************/
476+
477+
vecmem::data::vector_buffer<device::sort_key> keys_buffer(
478+
global_counter_host.n_candidates, m_mr.main);
479+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(
480+
global_counter_host.n_candidates, m_mr.main);
481+
482+
nThreads = m_warp_size * 2;
483+
484+
if (global_counter_host.n_candidates > 0) {
485+
nBlocks =
486+
(global_counter_host.n_candidates + nThreads - 1) / nThreads;
487+
kernels::get_sort_key_value<<<nBlocks, nThreads, 0, stream>>>(
488+
updated_params_buffer, keys_buffer, param_ids_buffer);
489+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
490+
}
491+
492+
// Sort the key and values
493+
vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
494+
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
495+
thrust::sort_by_key(thrust::cuda::par.on(stream), keys_device.begin(),
496+
keys_device.end(), param_ids_device.begin(),
497+
device::sort_key_comp());
498+
499+
/*****************************************************************
500+
* Kernel7: Propagate to the next surface
461501
*****************************************************************/
462502

463503
// Buffer for out parameters for the next step
@@ -482,8 +522,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
482522
config_type>
483523
<<<nBlocks, nThreads, 0, stream>>>(
484524
m_cfg, det_view, field_view, updated_params_buffer,
485-
link_map[step], step, (*global_counter_device).n_candidates,
486-
out_params_buffer, param_to_link_map[step], tips_map[step],
525+
param_ids_buffer, link_map[step], step,
526+
(*global_counter_device).n_candidates, out_params_buffer,
527+
param_to_link_map[step], tips_map[step],
487528
n_tracks_per_seed_buffer,
488529
(*global_counter_device).n_out_params);
489530
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
@@ -569,7 +610,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
569610
}
570611

571612
/*****************************************************************
572-
* Kernel7: Build tracks
613+
* Kernel8: Build tracks
573614
*****************************************************************/
574615

575616
// Create track candidate buffer

‎device/cuda/src/fitting/fitting_algorithm.cu

+35-2
Original file line numberDiff line numberDiff line change
@@ -10,31 +10,46 @@
1010
#include "../utils/utils.hpp"
1111
#include "traccc/cuda/fitting/fitting_algorithm.hpp"
1212
#include "traccc/fitting/device/fit.hpp"
13+
#include "traccc/fitting/device/get_sort_key_value.hpp"
1314
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
1415

1516
// detray include(s).
1617
#include "detray/core/detector_metadata.hpp"
1718
#include "detray/detectors/bfield.hpp"
1819
#include "detray/propagator/rk_stepper.hpp"
1920

21+
// Thrust include(s).
22+
#include <thrust/sort.h>
23+
2024
// System include(s).
2125
#include <vector>
2226

2327
namespace traccc::cuda {
2428

2529
namespace kernels {
2630

31+
__global__ void get_sort_key_value(
32+
track_candidate_container_types::const_view track_candidates_view,
33+
vecmem::data::vector_view<device::sort_key> keys_view,
34+
vecmem::data::vector_view<unsigned int> ids_view) {
35+
36+
int gid = threadIdx.x + blockIdx.x * blockDim.x;
37+
38+
device::get_sort_key_value(gid, track_candidates_view, keys_view, ids_view);
39+
}
40+
2741
template <typename fitter_t, typename detector_view_t>
2842
__global__ void fit(
2943
detector_view_t det_data, const typename fitter_t::bfield_type field_data,
3044
const typename fitter_t::config_type cfg,
3145
track_candidate_container_types::const_view track_candidates_view,
46+
vecmem::data::vector_view<const unsigned int> param_ids_view,
3247
track_state_container_types::view track_states_view) {
3348

3449
int gid = threadIdx.x + blockIdx.x * blockDim.x;
3550

3651
device::fit<fitter_t>(gid, det_data, field_data, cfg, track_candidates_view,
37-
track_states_view);
52+
param_ids_view, track_states_view);
3853
}
3954

4055
} // namespace kernels
@@ -76,16 +91,34 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
7691
m_copy.setup(track_states_buffer.headers);
7792
m_copy.setup(track_states_buffer.items);
7893

94+
vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks,
95+
m_mr.main);
96+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks,
97+
m_mr.main);
98+
7999
// Calculate the number of threads and thread blocks to run the track
80100
// fitting
81101
if (n_tracks > 0) {
82102
const unsigned int nThreads = m_warp_size * 2;
83103
const unsigned int nBlocks = (n_tracks + nThreads - 1) / nThreads;
84104

105+
// Get key and value for sorting
106+
kernels::get_sort_key_value<<<nBlocks, nThreads, 0, stream>>>(
107+
track_candidates_view, keys_buffer, param_ids_buffer);
108+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
109+
110+
// Sort the key to get the sorted parameter ids
111+
vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
112+
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
113+
114+
thrust::sort_by_key(thrust::cuda::par.on(stream), keys_device.begin(),
115+
keys_device.end(), param_ids_device.begin(),
116+
device::sort_key_comp());
117+
85118
// Run the track fitting
86119
kernels::fit<fitter_t><<<nBlocks, nThreads, 0, stream>>>(
87120
det_view, field_view, m_cfg, track_candidates_view,
88-
track_states_buffer);
121+
param_ids_buffer, track_states_buffer);
89122
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
90123
}
91124

0 commit comments

Comments
 (0)
Please sign in to comment.