Skip to content

Commit ed91591

Browse files
committed
Make sure that all SYCL kernels would have a unique kernel class.
Trying to avoid confusion at runtime about which kernel is which.
1 parent 268a632 commit ed91591

6 files changed

+84
-20
lines changed

device/sycl/src/finding/combinatorial_kalman_filter_algorithm_constant_field_default_detector.sycl

+26-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,28 @@
1717
#include <detray/propagator/rk_stepper.hpp>
1818

1919
namespace traccc::sycl {
20+
namespace kernels::combinatorial_kalman_filter_constant_field_default_detector {
21+
22+
struct make_barcode_sequence;
23+
struct apply_interaction;
24+
struct find_tracks;
25+
struct fill_sort_keys;
26+
struct propagate_to_next_surface;
27+
struct build_tracks;
28+
struct prune_tracks;
29+
30+
struct kernels {
31+
using make_barcode_sequence_kernel_type = make_barcode_sequence;
32+
using apply_interaction_kernel_type = apply_interaction;
33+
using find_tracks_kernel_type = find_tracks;
34+
using fill_sort_keys_kernel_type = fill_sort_keys;
35+
using propagate_to_next_surface_kernel_type = propagate_to_next_surface;
36+
using build_tracks_kernel_type = build_tracks;
37+
using prune_tracks_kernel_type = prune_tracks;
38+
}; // namespace kernels
39+
40+
} // namespace
41+
// kernels::combinatorial_kalman_filter_constant_field_default_detector
2042

2143
combinatorial_kalman_filter_algorithm::output_type
2244
combinatorial_kalman_filter_algorithm::operator()(
@@ -30,9 +52,10 @@ combinatorial_kalman_filter_algorithm::operator()(
3052
detray::rk_stepper<detray::bfield::const_field_t::view_t,
3153
default_detector::device::algebra_type,
3254
detray::constrained_step<>>,
33-
detray::navigator<const default_detector::device>>(
34-
det, field, measurements, seeds, m_config, m_mr, m_copy,
35-
details::get_queue(m_queue));
55+
detray::navigator<const default_detector::device>,
56+
kernels::combinatorial_kalman_filter_constant_field_default_detector::
57+
kernels>(det, field, measurements, seeds, m_config, m_mr, m_copy,
58+
details::get_queue(m_queue));
3659
}
3760

3861
} // namespace traccc::sycl

device/sycl/src/finding/combinatorial_kalman_filter_algorithm_constant_field_telescope_detector.sycl

+27-3
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,29 @@
1717
#include <detray/propagator/rk_stepper.hpp>
1818

1919
namespace traccc::sycl {
20+
namespace kernels::
21+
combinatorial_kalman_filter_constant_field_telescope_detector {
22+
23+
struct make_barcode_sequence;
24+
struct apply_interaction;
25+
struct find_tracks;
26+
struct fill_sort_keys;
27+
struct propagate_to_next_surface;
28+
struct build_tracks;
29+
struct prune_tracks;
30+
31+
struct kernels {
32+
using make_barcode_sequence_kernel_type = make_barcode_sequence;
33+
using apply_interaction_kernel_type = apply_interaction;
34+
using find_tracks_kernel_type = find_tracks;
35+
using fill_sort_keys_kernel_type = fill_sort_keys;
36+
using propagate_to_next_surface_kernel_type = propagate_to_next_surface;
37+
using build_tracks_kernel_type = build_tracks;
38+
using prune_tracks_kernel_type = prune_tracks;
39+
}; // namespace kernels
40+
41+
} // namespace
42+
// kernels::combinatorial_kalman_filter_constant_field_telescope_detector
2043

2144
combinatorial_kalman_filter_algorithm::output_type
2245
combinatorial_kalman_filter_algorithm::operator()(
@@ -30,9 +53,10 @@ combinatorial_kalman_filter_algorithm::operator()(
3053
detray::rk_stepper<detray::bfield::const_field_t::view_t,
3154
telescope_detector::device::algebra_type,
3255
detray::constrained_step<>>,
33-
detray::navigator<const telescope_detector::device>>(
34-
det, field, measurements, seeds, m_config, m_mr, m_copy,
35-
details::get_queue(m_queue));
56+
detray::navigator<const telescope_detector::device>,
57+
kernels::combinatorial_kalman_filter_constant_field_telescope_detector::
58+
kernels>(det, field, measurements, seeds, m_config, m_mr, m_copy,
59+
details::get_queue(m_queue));
3660
}
3761

3862
} // namespace traccc::sycl

device/sycl/src/finding/find_tracks.hpp

+13-8
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ namespace traccc::sycl::details {
5959
///
6060
/// @tparam stepper_t The stepper type used for the track propagation
6161
/// @tparam navigator_t The navigator type used for the track navigation
62+
/// @tparam kernels_t Structure with unique "kernel structures"
6263
///
6364
/// @param det A view of the detector object
6465
/// @param field The magnetic field object
@@ -72,7 +73,7 @@ namespace traccc::sycl::details {
7273
///
7374
/// @return A buffer of the found track candidates
7475
///
75-
template <typename stepper_t, typename navigator_t>
76+
template <typename stepper_t, typename navigator_t, typename kernels_t>
7677
track_candidate_container_types::buffer find_tracks(
7778
const typename navigator_t::detector_type::view_type& det,
7879
const typename stepper_t::magnetic_field_type& field,
@@ -129,7 +130,8 @@ track_candidate_container_types::buffer find_tracks(
129130

130131
queue
131132
.submit([&](::sycl::handler& h) {
132-
h.parallel_for(
133+
h.parallel_for<
134+
typename kernels_t::make_barcode_sequence_kernel_type>(
133135
calculate1DimNdRange(n_modules, 64),
134136
[uniques_view = vecmem::get_data(uniques_buffer),
135137
barcodes_view = vecmem::get_data(barcodes_buffer)](
@@ -182,7 +184,8 @@ track_candidate_container_types::buffer find_tracks(
182184

183185
queue
184186
.submit([&](::sycl::handler& h) {
185-
h.parallel_for(
187+
h.parallel_for<
188+
typename kernels_t::apply_interaction_kernel_type>(
186189
calculate1DimNdRange(n_in_params, 64),
187190
[config, det, n_in_params,
188191
in_params = vecmem::get_data(in_params_buffer),
@@ -244,7 +247,7 @@ track_candidate_container_types::buffer find_tracks(
244247
shared_candidates_size(1, h);
245248

246249
// Launch the kernel.
247-
h.parallel_for(
250+
h.parallel_for<typename kernels_t::find_tracks_kernel_type>(
248251
calculate1DimNdRange(n_in_params, nFindTracksThreads),
249252
[config, det, measurements,
250253
in_params = vecmem::get_data(in_params_buffer),
@@ -308,7 +311,8 @@ track_candidate_container_types::buffer find_tracks(
308311

309312
queue
310313
.submit([&](::sycl::handler& h) {
311-
h.parallel_for(
314+
h.parallel_for<
315+
typename kernels_t::fill_sort_keys_kernel_type>(
312316
calculate1DimNdRange(n_candidates, 256),
313317
[in_params = vecmem::get_data(in_params_buffer),
314318
keys = vecmem::get_data(keys_buffer),
@@ -356,7 +360,8 @@ track_candidate_container_types::buffer find_tracks(
356360
// surface.
357361
queue
358362
.submit([&](::sycl::handler& h) {
359-
h.parallel_for(
363+
h.parallel_for<typename kernels_t::
364+
propagate_to_next_surface_kernel_type>(
360365
calculate1DimNdRange(n_candidates, 64),
361366
[config, det, field,
362367
in_params = vecmem::get_data(in_params_buffer),
@@ -440,7 +445,7 @@ track_candidate_container_types::buffer find_tracks(
440445

441446
queue
442447
.submit([&](::sycl::handler& h) {
443-
h.parallel_for(
448+
h.parallel_for<typename kernels_t::build_tracks_kernel_type>(
444449
calculate1DimNdRange(n_tips_total, 64),
445450
[config, measurements, seeds,
446451
links = vecmem::get_data(links_buffer),
@@ -478,7 +483,7 @@ track_candidate_container_types::buffer find_tracks(
478483

479484
queue
480485
.submit([&](::sycl::handler& h) {
481-
h.parallel_for(
486+
h.parallel_for<typename kernels_t::prune_tracks_kernel_type>(
482487
calculate1DimNdRange(n_valid_tracks, 64),
483488
[track_candidates,
484489
valid_indices = vecmem::get_data(valid_indices_buffer),

device/sycl/src/seeding/silicon_pixel_spacepoint_formation.hpp

+3-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ namespace traccc::sycl::details {
2828
/// functions
2929
///
3030
/// @tparam detector_t The detector type to use
31+
/// @tparam kernel_t The kernel type to use
3132
///
3233
/// @param det_view The view of the detector to use
3334
/// @param measurements_view The view of the measurements to process
@@ -36,7 +37,7 @@ namespace traccc::sycl::details {
3637
/// @param queue The queue to use for the computation
3738
/// @return A buffer of the created spacepoints
3839
///
39-
template <typename detector_t>
40+
template <typename detector_t, typename kernel_t>
4041
spacepoint_collection_types::buffer silicon_pixel_spacepoint_formation(
4142
const typename detector_t::view_type& det_view,
4243
const measurement_collection_types::const_view& measurements_view,
@@ -64,7 +65,7 @@ spacepoint_collection_types::buffer silicon_pixel_spacepoint_formation(
6465
// Run the spacepoint formation on the device.
6566
queue
6667
.submit([&](cl::sycl::handler& h) {
67-
h.parallel_for(
68+
h.parallel_for<kernel_t>(
6869
countRange, [det_view, measurements_view, n_measurements,
6970
spacepoints_view = vecmem::get_data(result)](
7071
cl::sycl::nd_item<1> item) {

device/sycl/src/seeding/silicon_pixel_spacepoint_formation_algorithm_default_detector.sycl

+7-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,20 @@
1111
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
1212

1313
namespace traccc::sycl {
14+
namespace kernels {
15+
16+
struct form_spacepoints_default_detector;
17+
18+
} // namespace kernels
1419

1520
silicon_pixel_spacepoint_formation_algorithm::output_type
1621
silicon_pixel_spacepoint_formation_algorithm::operator()(
1722
const default_detector::view& det,
1823
const measurement_collection_types::const_view& meas) const {
1924

2025
return details::silicon_pixel_spacepoint_formation<
21-
default_detector::device>(det, meas, m_mr.main, m_copy,
22-
details::get_queue(m_queue));
26+
default_detector::device, kernels::form_spacepoints_default_detector>(
27+
det, meas, m_mr.main, m_copy, details::get_queue(m_queue));
2328
}
2429

2530
} // namespace traccc::sycl

device/sycl/src/seeding/silicon_pixel_spacepoint_formation_algorithm_telescope_detector.sycl

+8-2
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,21 @@
1111
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
1212

1313
namespace traccc::sycl {
14+
namespace kernels {
15+
16+
struct form_spacepoints_telescope_detector;
17+
18+
} // namespace kernels
1419

1520
silicon_pixel_spacepoint_formation_algorithm::output_type
1621
silicon_pixel_spacepoint_formation_algorithm::operator()(
1722
const telescope_detector::view& det,
1823
const measurement_collection_types::const_view& meas) const {
1924

2025
return details::silicon_pixel_spacepoint_formation<
21-
telescope_detector::device>(det, meas, m_mr.main, m_copy,
22-
details::get_queue(m_queue));
26+
telescope_detector::device,
27+
kernels::form_spacepoints_telescope_detector>(
28+
det, meas, m_mr.main, m_copy, details::get_queue(m_queue));
2329
}
2430

2531
} // namespace traccc::sycl

0 commit comments

Comments
 (0)