Skip to content

Commit b1b8ee4

Browse files
committed
Access the track parametes in a sorted order for KF propagation
1 parent e273246 commit b1b8ee4

File tree

14 files changed

+342
-12
lines changed

14 files changed

+342
-12
lines changed

CMakeLists.txt

+16
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,22 @@ if( TRACCC_SETUP_TBB )
160160
endif()
161161
endif()
162162

163+
# Set up DPL if SYCL is built.
164+
if( TRACCC_BUILD_SYCL )
165+
option( TRACCC_SETUP_DPL
166+
"Set up the DPL target(s) explicitly" TRUE )
167+
option( TRACCC_USE_SYSTEM_DPL
168+
"Pick up an existing installation of DPL from the build environment"
169+
${TRACCC_USE_SYSTEM_LIBS} )
170+
if( TRACCC_SETUP_DPL )
171+
if( TRACCC_USE_SYSTEM_DPL )
172+
find_package( DPL REQUIRED )
173+
else()
174+
add_subdirectory( extern/dpl )
175+
endif()
176+
endif()
177+
endif()
178+
163179
# Set up Kokkos.
164180
option( TRACCC_SETUP_KOKKOS
165181
"Set up the Kokkos library" ${TRACCC_BUILD_KOKKOS} )

device/common/CMakeLists.txt

+2
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,8 @@ traccc_add_library( traccc_device_common device_common TYPE SHARED
7979
# Track fitting funtions(s).
8080
"include/traccc/fitting/device/fit.hpp"
8181
"include/traccc/fitting/device/impl/fit.ipp"
82+
"include/traccc/fitting/device/fill_sort_keys.hpp"
83+
"include/traccc/fitting/device/impl/fill_sort_keys.ipp"
8284
)
8385
target_link_libraries( traccc_device_common
8486
PUBLIC traccc::Thrust traccc::core vecmem::core )

device/common/include/traccc/edm/device/sort_key.hpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
// Project include(s).
1111
#include "traccc/definitions/primitives.hpp"
12+
#include "traccc/edm/track_candidate.hpp"
1213
#include "traccc/edm/track_parameters.hpp"
1314

1415
namespace traccc::device {
@@ -18,7 +19,7 @@ using sort_key = traccc::scalar;
1819
TRACCC_HOST_DEVICE
1920
inline sort_key get_sort_key(const bound_track_parameters& params) {
2021
// key = |theta - pi/2|
21-
return math::abs(params.theta() - constant<traccc::scalar>::pi_2);
22+
return math::fabs(params.theta() - constant<traccc::scalar>::pi_2);
2223
}
2324

2425
} // namespace traccc::device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
// Project include(s).
11+
#include "traccc/edm/device/sort_key.hpp"
12+
#include "traccc/edm/track_candidate.hpp"
13+
14+
// System include(s).
15+
#include <cstddef>
16+
17+
namespace traccc::device {
18+
19+
/// Function used to fill key container
20+
///
21+
/// @param[in] globalIndex The index of the current thread
22+
/// @param[in] track_candidates_view The input track candidates
23+
/// @param[out] keys_view The key values
24+
/// @param[out] ids_view The param ids
25+
///
26+
TRACCC_HOST_DEVICE inline void fill_sort_keys(
27+
std::size_t globalIndex,
28+
const track_candidate_container_types::const_view& track_candidates_view,
29+
vecmem::data::vector_view<device::sort_key> keys_view,
30+
vecmem::data::vector_view<unsigned int> ids_view);
31+
32+
} // namespace traccc::device
33+
34+
// Include the implementation.
35+
#include "traccc/fitting/device/impl/fill_sort_keys.ipp"

device/common/include/traccc/fitting/device/fit.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TRACCC_HOST_DEVICE inline void fit(
3131
const typename fitter_t::bfield_type field_data,
3232
const typename fitter_t::config_type cfg,
3333
track_candidate_container_types::const_view track_candidates_view,
34+
const vecmem::data::vector_view<const unsigned int>& param_ids_view,
3435
track_state_container_types::view track_states_view);
3536

3637
} // namespace traccc::device
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
/** TRACCC library, part of the ACTS project (R&D line)
2+
*
3+
* (c) 2024 CERN for the benefit of the ACTS project
4+
*
5+
* Mozilla Public License Version 2.0
6+
*/
7+
8+
#pragma once
9+
10+
namespace traccc::device {
11+
12+
TRACCC_HOST_DEVICE inline void fill_sort_keys(
13+
std::size_t globalIndex,
14+
const track_candidate_container_types::const_view& track_candidates_view,
15+
vecmem::data::vector_view<device::sort_key> keys_view,
16+
vecmem::data::vector_view<unsigned int> ids_view) {
17+
18+
track_candidate_container_types::const_device track_candidates(
19+
track_candidates_view);
20+
21+
// Keys
22+
vecmem::device_vector<device::sort_key> keys_device(keys_view);
23+
24+
// Param id
25+
vecmem::device_vector<unsigned int> ids_device(ids_view);
26+
27+
if (globalIndex >= keys_device.size()) {
28+
return;
29+
}
30+
31+
// Key = The number of measurements
32+
keys_device.at(globalIndex) = static_cast<traccc::scalar>(
33+
track_candidates.at(globalIndex).items.size());
34+
ids_device.at(globalIndex) = globalIndex;
35+
}
36+
37+
} // namespace traccc::device

device/common/include/traccc/fitting/device/impl/fit.ipp

+9-4
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,16 @@ TRACCC_HOST_DEVICE inline void fit(
1616
const typename fitter_t::bfield_type field_data,
1717
const typename fitter_t::config_type cfg,
1818
track_candidate_container_types::const_view track_candidates_view,
19+
const vecmem::data::vector_view<const unsigned int>& param_ids_view,
1920
track_state_container_types::view track_states_view) {
2021

2122
typename fitter_t::detector_type det(det_data);
2223

2324
track_candidate_container_types::const_device track_candidates(
2425
track_candidates_view);
2526

27+
vecmem::device_vector<const unsigned int> param_ids(param_ids_view);
28+
2629
track_state_container_types::device track_states(track_states_view);
2730

2831
fitter_t fitter(det, field_data, cfg);
@@ -31,15 +34,17 @@ TRACCC_HOST_DEVICE inline void fit(
3134
return;
3235
}
3336

37+
const unsigned int param_id = param_ids.at(globalIndex);
38+
3439
// Track candidates per track
3540
const auto& track_candidates_per_track =
36-
track_candidates[globalIndex].items;
41+
track_candidates.at(param_id).items;
3742

3843
// Seed parameter
39-
const auto& seed_param = track_candidates[globalIndex].header;
44+
const auto& seed_param = track_candidates.at(param_id).header;
4045

4146
// Track states per track
42-
auto track_states_per_track = track_states[globalIndex].items;
47+
auto track_states_per_track = track_states.at(param_id).items;
4348

4449
for (auto& cand : track_candidates_per_track) {
4550
track_states_per_track.emplace_back(cand);
@@ -51,7 +56,7 @@ TRACCC_HOST_DEVICE inline void fit(
5156
fitter.fit(seed_param, fitter_state);
5257

5358
// Get the final fitting information
54-
track_states[globalIndex].header = fitter_state.m_fit_res;
59+
track_states.at(param_id).header = fitter_state.m_fit_res;
5560
}
5661

5762
} // namespace traccc::device

device/cuda/src/fitting/fitting_algorithm.cu

+33-2
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "../utils/cuda_error_handling.hpp"
1010
#include "../utils/utils.hpp"
1111
#include "traccc/cuda/fitting/fitting_algorithm.hpp"
12+
#include "traccc/fitting/device/fill_sort_keys.hpp"
1213
#include "traccc/fitting/device/fit.hpp"
1314
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
1415

@@ -17,24 +18,37 @@
1718
#include "detray/detectors/bfield.hpp"
1819
#include "detray/propagator/rk_stepper.hpp"
1920

21+
// Thrust include(s).
22+
#include <thrust/sort.h>
23+
2024
// System include(s).
2125
#include <vector>
2226

2327
namespace traccc::cuda {
2428

2529
namespace kernels {
2630

31+
__global__ void fill_sort_keys(
32+
track_candidate_container_types::const_view track_candidates_view,
33+
vecmem::data::vector_view<device::sort_key> keys_view,
34+
vecmem::data::vector_view<unsigned int> ids_view) {
35+
36+
device::fill_sort_keys(threadIdx.x + blockIdx.x * blockDim.x,
37+
track_candidates_view, keys_view, ids_view);
38+
}
39+
2740
template <typename fitter_t, typename detector_view_t>
2841
__global__ void fit(
2942
detector_view_t det_data, const typename fitter_t::bfield_type field_data,
3043
const typename fitter_t::config_type cfg,
3144
track_candidate_container_types::const_view track_candidates_view,
45+
vecmem::data::vector_view<const unsigned int> param_ids_view,
3246
track_state_container_types::view track_states_view) {
3347

3448
int gid = threadIdx.x + blockIdx.x * blockDim.x;
3549

3650
device::fit<fitter_t>(gid, det_data, field_data, cfg, track_candidates_view,
37-
track_states_view);
51+
param_ids_view, track_states_view);
3852
}
3953

4054
} // namespace kernels
@@ -82,10 +96,27 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
8296
const unsigned int nThreads = m_warp_size * 2;
8397
const unsigned int nBlocks = (n_tracks + nThreads - 1) / nThreads;
8498

99+
vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks,
100+
m_mr.main);
101+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks,
102+
m_mr.main);
103+
104+
// Get key and value for sorting
105+
kernels::fill_sort_keys<<<nBlocks, nThreads, 0, stream>>>(
106+
track_candidates_view, keys_buffer, param_ids_buffer);
107+
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
108+
109+
// Sort the key to get the sorted parameter ids
110+
vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
111+
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
112+
113+
thrust::sort_by_key(thrust::cuda::par.on(stream), keys_device.begin(),
114+
keys_device.end(), param_ids_device.begin());
115+
85116
// Run the track fitting
86117
kernels::fit<fitter_t><<<nBlocks, nThreads, 0, stream>>>(
87118
det_view, field_view, m_cfg, track_candidates_view,
88-
track_states_buffer);
119+
param_ids_buffer, track_states_buffer);
89120
TRACCC_CUDA_ERROR_CHECK(cudaGetLastError());
90121
}
91122

device/sycl/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,4 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED
4545
)
4646
target_link_libraries( traccc_sycl
4747
PUBLIC traccc::core detray::core vecmem::core covfie::core
48-
PRIVATE traccc::device_common vecmem::sycl )
48+
PRIVATE traccc::device_common vecmem::sycl oneDPL )

device/sycl/src/fitting/fitting_algorithm.sycl

+40-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
// Project include(s).
99
#include "../utils/get_queue.hpp"
10+
#include "traccc/fitting/device/fill_sort_keys.hpp"
1011
#include "traccc/fitting/device/fit.hpp"
1112
#include "traccc/fitting/kalman_filter/kalman_fitter.hpp"
1213
#include "traccc/sycl/fitting/fitting_algorithm.hpp"
@@ -18,6 +19,10 @@
1819
#include "detray/navigation/navigator.hpp"
1920
#include "detray/propagator/rk_stepper.hpp"
2021

22+
// DPL include(s).
23+
#include <oneapi/dpl/algorithm>
24+
#include <oneapi/dpl/execution>
25+
2126
// System include(s).
2227
#include <vector>
2328

@@ -27,6 +32,9 @@ namespace kernels {
2732
/// Class identifying the kernel running @c
2833
/// traccc::device::fit
2934
class fit;
35+
/// Class identifying the kernel running @c
36+
/// traccc::device::fill_sort_keys
37+
class fill_sort_keys;
3038
} // namespace kernels
3139

3240
template <typename fitter_t>
@@ -74,19 +82,47 @@ track_state_container_types::buffer fitting_algorithm<fitter_t>::operator()(
7482
// (=32)
7583
unsigned int localSize = 64;
7684

85+
vecmem::data::vector_buffer<device::sort_key> keys_buffer(n_tracks,
86+
m_mr.main);
87+
vecmem::data::vector_buffer<unsigned int> param_ids_buffer(n_tracks,
88+
m_mr.main);
89+
vecmem::data::vector_view<device::sort_key> keys_view(keys_buffer);
90+
vecmem::data::vector_view<unsigned int> param_ids_view(param_ids_buffer);
91+
92+
// Sort the key to get the sorted parameter ids
93+
vecmem::device_vector<device::sort_key> keys_device(keys_buffer);
94+
vecmem::device_vector<unsigned int> param_ids_device(param_ids_buffer);
95+
7796
// 1 dim ND Range for the kernel
7897
auto trackParamsNdRange =
7998
traccc::sycl::calculate1DimNdRange(n_tracks, localSize);
8099

100+
details::get_queue(m_queue)
101+
.submit([&](::sycl::handler& h) {
102+
h.parallel_for<kernels::fill_sort_keys>(
103+
trackParamsNdRange, [track_candidates_view, keys_view,
104+
param_ids_view](::sycl::nd_item<1> item) {
105+
device::fill_sort_keys(item.get_global_linear_id(),
106+
track_candidates_view, keys_view,
107+
param_ids_view);
108+
});
109+
})
110+
.wait_and_throw();
111+
112+
oneapi::dpl::sort_by_key(oneapi::dpl::execution::dpcpp_default,
113+
keys_device.begin(), keys_device.end(),
114+
param_ids_device.begin());
115+
81116
details::get_queue(m_queue)
82117
.submit([&](::sycl::handler& h) {
83118
h.parallel_for<kernels::fit>(
84119
trackParamsNdRange,
85120
[det_view, field_view, config = m_cfg, track_candidates_view,
86-
track_states_view](::sycl::nd_item<1> item) {
87-
device::fit<fitter_t>(
88-
item.get_global_linear_id(), det_view, field_view,
89-
config, track_candidates_view, track_states_view);
121+
param_ids_view, track_states_view](::sycl::nd_item<1> item) {
122+
device::fit<fitter_t>(item.get_global_linear_id(), det_view,
123+
field_view, config,
124+
track_candidates_view, param_ids_view,
125+
track_states_view);
90126
});
91127
})
92128
.wait_and_throw();

extern/dpl/CMakeLists.txt

+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# TRACCC library, part of the ACTS project (R&D line)
2+
#
3+
# (c) 2024 CERN for the benefit of the ACTS project
4+
#
5+
# Mozilla Public License Version 2.0
6+
7+
# CMake include(s).
8+
cmake_minimum_required( VERSION 3.14 )
9+
include( FetchContent )
10+
11+
# Silence FetchContent warnings with CMake >=3.24.
12+
if( POLICY CMP0135 )
13+
cmake_policy( SET CMP0135 NEW )
14+
endif()
15+
16+
# Tell the user what's happening.
17+
message( STATUS "Building oneDPL as part of the TRACCC project" )
18+
19+
# Declare where to get DPL from.
20+
set( TRACCC_DPL_SOURCE
21+
"URL;https://github.com/oneapi-src/oneDPL/archive/refs/tags/oneDPL-2022.6.0-rc1.tar.gz;URL_MD5;f52a2ed5c9e4cdb3c65c2465b50abecf"
22+
CACHE STRING "Source for DPL, when built as part of this project" )
23+
mark_as_advanced( TRACCC_DPL_SOURCE )
24+
FetchContent_Declare( DPL ${TRACCC_DPL_SOURCE} )
25+
26+
# Set the default oneDPL threading backend.
27+
set( ONEDPL_BACKEND "dpcpp" CACHE STRING "oneDPL threading backend" )
28+
29+
# Get it into the current directory.
30+
FetchContent_MakeAvailable( DPL )
31+
32+
# Treat the oneDPL headers as "system headers", to avoid getting warnings from
33+
# them.
34+
get_target_property( _incDirs oneDPL INTERFACE_INCLUDE_DIRECTORIES )
35+
target_include_directories( oneDPL
36+
SYSTEM INTERFACE ${_incDirs} )
37+
unset( _incDirs )

extern/dpl/README.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Build Recipe for oneDPL
2+
3+
This directory holds a build recipe for building
4+
[DPL](https://github.com/oneapi-src/oneDPL) for this project.

tests/sycl/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ traccc_add_test(
1515
# TODO: Reactivate this once #655 is fixed.
1616
# test_kalman_fitter_telescope.sycl
1717
test_clusterization.sycl
18+
test_dpl.sycl
1819
test_spacepoint_formation.sycl
1920
test_barrier.sycl
2021
test_mutex.sycl

0 commit comments

Comments
 (0)