13
13
#include " traccc/definitions/primitives.hpp"
14
14
#include " traccc/definitions/qualifiers.hpp"
15
15
#include " traccc/edm/device/finding_global_counter.hpp"
16
+ #include " traccc/edm/device/sort_key.hpp"
16
17
#include " traccc/finding/candidate_link.hpp"
17
18
#include " traccc/finding/device/add_links_for_holes.hpp"
18
19
#include " traccc/finding/device/apply_interaction.hpp"
19
20
#include " traccc/finding/device/build_tracks.hpp"
20
21
#include " traccc/finding/device/count_measurements.hpp"
21
22
#include " traccc/finding/device/find_tracks.hpp"
23
+ #include " traccc/finding/device/get_sort_key_value.hpp"
22
24
#include " traccc/finding/device/make_barcode_sequence.hpp"
23
25
#include " traccc/finding/device/propagate_to_next_surface.hpp"
24
26
#include " traccc/finding/device/prune_tracks.hpp"
@@ -137,13 +139,25 @@ __global__ void add_links_for_holes(
137
139
n_total_candidates);
138
140
}
139
141
142
+ // / CUDA kernel for running @c traccc::device::get_sort_key_value
143
+ __global__ void get_sort_key_value (
144
+ bound_track_parameters_collection_types::const_view params_view,
145
+ vecmem::data::vector_view<device::sort_key> keys_view,
146
+ vecmem::data::vector_view<unsigned int > ids_view) {
147
+
148
+ int gid = threadIdx .x + blockIdx .x * blockDim .x ;
149
+
150
+ device::get_sort_key_value (gid, params_view, keys_view, ids_view);
151
+ }
152
+
140
153
// / CUDA kernel for running @c traccc::device::propagate_to_next_surface
141
154
template <typename propagator_t , typename bfield_t , typename config_t >
142
155
__global__ void propagate_to_next_surface (
143
156
const config_t cfg,
144
157
typename propagator_t ::detector_type::view_type det_data,
145
158
bfield_t field_data,
146
159
bound_track_parameters_collection_types::const_view in_params_view,
160
+ vecmem::data::vector_view<const unsigned int > param_ids_view,
147
161
vecmem::data::vector_view<const candidate_link> links_view,
148
162
const unsigned int step, const unsigned int & n_candidates,
149
163
bound_track_parameters_collection_types::view out_params_view,
@@ -156,9 +170,9 @@ __global__ void propagate_to_next_surface(
156
170
int gid = threadIdx .x + blockIdx .x * blockDim .x ;
157
171
158
172
device::propagate_to_next_surface<propagator_t , bfield_t , config_t >(
159
- gid, cfg, det_data, field_data, in_params_view, links_view, step ,
160
- n_candidates, out_params_view, param_to_link_view, tips_view ,
161
- n_tracks_per_seed_view, n_out_params);
173
+ gid, cfg, det_data, field_data, in_params_view, param_ids_view ,
174
+ links_view, step, n_candidates, out_params_view, param_to_link_view,
175
+ tips_view, n_tracks_per_seed_view, n_out_params);
162
176
}
163
177
164
178
// / CUDA kernel for running @c traccc::device::build_tracks
@@ -457,7 +471,33 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
457
471
m_stream.synchronize ();
458
472
459
473
/* ****************************************************************
460
- * Kernel6: Propagate to the next surface
474
+ * Kernel6: Get key and value for parameter sorting
475
+ *****************************************************************/
476
+
477
+ vecmem::data::vector_buffer<device::sort_key> keys_buffer (
478
+ global_counter_host.n_candidates , m_mr.main );
479
+ vecmem::data::vector_buffer<unsigned int > param_ids_buffer (
480
+ global_counter_host.n_candidates , m_mr.main );
481
+
482
+ nThreads = m_warp_size * 2 ;
483
+
484
+ if (global_counter_host.n_candidates > 0 ) {
485
+ nBlocks =
486
+ (global_counter_host.n_candidates + nThreads - 1 ) / nThreads;
487
+ kernels::get_sort_key_value<<<nBlocks, nThreads, 0 , stream>>> (
488
+ updated_params_buffer, keys_buffer, param_ids_buffer);
489
+ TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
490
+ }
491
+
492
+ // Sort the key and values
493
+ vecmem::device_vector<device::sort_key> keys_device (keys_buffer);
494
+ vecmem::device_vector<unsigned int > param_ids_device (param_ids_buffer);
495
+ thrust::sort_by_key (thrust::cuda::par.on (stream), keys_device.begin (),
496
+ keys_device.end (), param_ids_device.begin (),
497
+ device::sort_key_comp ());
498
+
499
+ /* ****************************************************************
500
+ * Kernel7: Propagate to the next surface
461
501
*****************************************************************/
462
502
463
503
// Buffer for out parameters for the next step
@@ -482,8 +522,9 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
482
522
config_type>
483
523
<<<nBlocks, nThreads, 0 , stream>>> (
484
524
m_cfg, det_view, field_view, updated_params_buffer,
485
- link_map[step], step, (*global_counter_device).n_candidates ,
486
- out_params_buffer, param_to_link_map[step], tips_map[step],
525
+ param_ids_buffer, link_map[step], step,
526
+ (*global_counter_device).n_candidates , out_params_buffer,
527
+ param_to_link_map[step], tips_map[step],
487
528
n_tracks_per_seed_buffer,
488
529
(*global_counter_device).n_out_params );
489
530
TRACCC_CUDA_ERROR_CHECK (cudaGetLastError ());
@@ -569,7 +610,7 @@ finding_algorithm<stepper_t, navigator_t>::operator()(
569
610
}
570
611
571
612
/* ****************************************************************
572
- * Kernel7 : Build tracks
613
+ * Kernel8 : Build tracks
573
614
*****************************************************************/
574
615
575
616
// Create track candidate buffer
0 commit comments