Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SYCL Track Finding in the Example Executables, main branch (2025.01.09.) #812

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 131 additions & 120 deletions README.md

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions device/sycl/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ traccc_add_library( traccc_sycl sycl TYPE SHARED
# Clusterization algorithm(s).
"include/traccc/sycl/clusterization/clusterization_algorithm.hpp"
"src/clusterization/clusterization_algorithm.sycl"
"include/traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"
"src/clusterization/measurement_sorting_algorithm.sycl"
# Seeding algorithm(s).
"include/traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
"src/seeding/silicon_pixel_spacepoint_formation_algorithm.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

#pragma once

// Local include(s).
#include "traccc/sycl/utils/queue_wrapper.hpp"

// Project include(s).
#include "traccc/edm/measurement.hpp"
#include "traccc/utils/algorithm.hpp"

// VecMem include(s).
#include <vecmem/utils/copy.hpp>

// System include(s).
#include <functional>

namespace traccc::sycl {

/// Algorithm sorting the reconstructed measurements in their container
///
/// The track finding algorithm expects measurements belonging to a single
/// detector module to be consecutive in memory. But
/// @c traccc::sycl::clusterization_algorithm does not (currently) produce the
/// measurements in such an ordered state. This is where this algorithm comes
/// to the rescue.
///
class measurement_sorting_algorithm
: public algorithm<measurement_collection_types::view(
const measurement_collection_types::view&)> {

public:
/// Constructor for the algorithm
///
/// @param copy The copy object to use in the algorithm
/// @param queue Wrapper for the for the SYCL queue for kernel invocation
///
measurement_sorting_algorithm(vecmem::copy& copy, queue_wrapper& queue);

/// Callable operator performing the sorting on a container
///
/// @param measurements The measurements to sort
///
output_type operator()(const measurement_collection_types::view&
measurements_view) const override;

private:
/// Copy object to use in the algorithm
std::reference_wrapper<vecmem::copy> m_copy;
/// The SYCL queue to use
std::reference_wrapper<queue_wrapper> m_queue;

}; // class measurement_sorting_algorithm

} // namespace traccc::sycl
52 changes: 52 additions & 0 deletions device/sycl/src/clusterization/measurement_sorting_algorithm.sycl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/

// Local include(s).
#include "../utils/get_queue.hpp"
#include "traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"

// oneDPL include(s).
#include <oneapi/dpl/algorithm>
#include <oneapi/dpl/execution>

// SYCL include(s).
#include <sycl/sycl.hpp>

namespace traccc::sycl {

measurement_sorting_algorithm::measurement_sorting_algorithm(
vecmem::copy& copy, queue_wrapper& queue)
: m_copy{copy}, m_queue{queue} {}

measurement_sorting_algorithm::output_type
measurement_sorting_algorithm::operator()(
const measurement_collection_types::view& measurements_view) const {

// Get the SYCL queue to use for the algorithm.
::sycl::queue& queue = details::get_queue(m_queue.get());

// oneDPL policy to use, forcing execution onto the same device that the
// hand-written kernels would run on.
auto policy = oneapi::dpl::execution::device_policy{queue};

// Get the number of measurements. This is necessary because the input
// container may not be fixed sized. And we can't give invalid pointers /
// iterators to oneDPL.
const measurement_collection_types::view::size_type n_measurements =
m_copy.get().get_size(measurements_view);

// Sort the measurements in place
oneapi::dpl::sort(policy, measurements_view.ptr(),
measurements_view.ptr() + n_measurements,
measurement_sort_comp());
queue.wait_and_throw();

// Return the view of the sorted measurements.
return measurements_view;
}

} // namespace traccc::sycl
32 changes: 28 additions & 4 deletions examples/run/sycl/full_chain_algorithm.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
/** TRACCC library, part of the ACTS project (R&D line)
*
* (c) 2022-2024 CERN for the benefit of the ACTS project
* (c) 2022-2025 CERN for the benefit of the ACTS project
*
* Mozilla Public License Version 2.0
*/
Expand All @@ -9,11 +9,12 @@

// Project include(s).
#include "traccc/edm/silicon_cell_collection.hpp"
#include "traccc/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/fitting/kalman_fitting_algorithm.hpp"
#include "traccc/geometry/detector.hpp"
#include "traccc/geometry/silicon_detector_description.hpp"
#include "traccc/sycl/clusterization/clusterization_algorithm.hpp"
#include "traccc/sycl/clusterization/measurement_sorting_algorithm.hpp"
#include "traccc/sycl/finding/combinatorial_kalman_filter_algorithm.hpp"
#include "traccc/sycl/seeding/seeding_algorithm.hpp"
#include "traccc/sycl/seeding/silicon_pixel_spacepoint_formation_algorithm.hpp"
#include "traccc/sycl/seeding/track_params_estimation.hpp"
Expand Down Expand Up @@ -72,7 +73,7 @@ class full_chain_algorithm
using clustering_algorithm = clusterization_algorithm;
/// Track finding algorithm type
using finding_algorithm =
traccc::host::combinatorial_kalman_filter_algorithm;
traccc::sycl::combinatorial_kalman_filter_algorithm;
/// Track fitting algorithm type
using fitting_algorithm = traccc::host::kalman_fitting_algorithm;

Expand Down Expand Up @@ -126,6 +127,11 @@ class full_chain_algorithm
/// Memory copy object
mutable vecmem::sycl::async_copy m_copy;

/// Constant B field for the (seed) track parameter estimation
traccc::vector3 m_field_vec;
/// Constant B field for the track finding and fitting
detray::bfield::const_field_t m_field;

/// Detector description
std::reference_wrapper<const silicon_detector_description::host>
m_det_descr;
Expand All @@ -140,21 +146,39 @@ class full_chain_algorithm

/// @name Sub-algorithms used by this full-chain algorithm
/// @{

/// Clusterization algorithm
clusterization_algorithm m_clusterization;
/// Measurement sorting algorithm
measurement_sorting_algorithm m_measurement_sorting;
/// Spacepoint formation algorithm
spacepoint_formation_algorithm m_spacepoint_formation;
/// Seeding algorithm
seeding_algorithm m_seeding;
/// Track parameter estimation algorithm
track_params_estimation m_track_parameter_estimation;
/// Track finding algorithm
finding_algorithm m_finding;

/// Configs
/// @}

/// @}

/// @name Algorithm configurations
/// @{

/// Configuration for clustering
clustering_config m_clustering_config;
/// Configuration for the seed finding
seedfinder_config m_finder_config;
/// Configuration for the spacepoint grid formation
spacepoint_grid_config m_grid_config;
/// Configuration for the seed filtering
seedfilter_config m_filter_config;

/// Configuration for the track finding
finding_algorithm::config_type m_finding_config;

/// @}

}; // class full_chain_algorithm
Expand Down
28 changes: 24 additions & 4 deletions examples/run/sycl/full_chain_algorithm.sycl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ full_chain_algorithm::full_chain_algorithm(
const seedfinder_config& finder_config,
const spacepoint_grid_config& grid_config,
const seedfilter_config& filter_config,
const finding_algorithm::config_type&,
const finding_algorithm::config_type& finding_config,
const fitting_algorithm::config_type&,
const silicon_detector_description::host& det_descr,
host_detector_type* detector)
Expand All @@ -64,6 +64,8 @@ full_chain_algorithm::full_chain_algorithm(
m_device_mr{&(m_data->m_queue)},
m_cached_device_mr{m_device_mr},
m_copy{&(m_data->m_queue)},
m_field_vec{0.f, 0.f, finder_config.bFieldInZ},
m_field{detray::bfield::create_const_field(m_field_vec)},
m_det_descr(det_descr),
m_device_det_descr{
static_cast<silicon_detector_description::buffer::size_type>(
Expand All @@ -73,6 +75,7 @@ full_chain_algorithm::full_chain_algorithm(
m_device_detector{},
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy, m_data->m_queue_wrapper, clustering_config},
m_measurement_sorting(m_copy, m_data->m_queue_wrapper),
m_spacepoint_formation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
Expand All @@ -85,10 +88,14 @@ full_chain_algorithm::full_chain_algorithm(
m_track_parameter_estimation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_finding{finding_config,
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_clustering_config(clustering_config),
m_finder_config(finder_config),
m_grid_config(grid_config),
m_filter_config(filter_config) {
m_filter_config(filter_config),
m_finding_config(finding_config) {

// Tell the user what device is being used.
std::cout
Expand All @@ -112,6 +119,8 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
m_device_mr{&(m_data->m_queue)},
m_cached_device_mr{m_device_mr},
m_copy{&(m_data->m_queue)},
m_field_vec{parent.m_field_vec},
m_field{parent.m_field},
m_det_descr(parent.m_det_descr),
m_device_det_descr{
static_cast<silicon_detector_description::buffer::size_type>(
Expand All @@ -122,6 +131,7 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy, m_data->m_queue_wrapper,
parent.m_clustering_config},
m_measurement_sorting(m_copy, m_data->m_queue_wrapper),
m_spacepoint_formation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
Expand All @@ -134,10 +144,14 @@ full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
m_track_parameter_estimation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_finding{parent.m_finding_config,
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_clustering_config(parent.m_clustering_config),
m_finder_config(parent.m_finder_config),
m_grid_config(parent.m_grid_config),
m_filter_config(parent.m_filter_config) {
m_filter_config(parent.m_filter_config),
m_finding_config(parent.m_finding_config) {

// Copy the detector (description) to the device.
m_copy(vecmem::get_data(m_det_descr.get()), m_device_det_descr)->wait();
Expand All @@ -161,21 +175,27 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
// Execute the algorithms.
const clusterization_algorithm::output_type measurements =
m_clusterization(cells_buffer, m_device_det_descr);
m_measurement_sorting(measurements);

// If we have a Detray detector, run the seeding, track
// finding and fitting.
if (m_detector != nullptr) {

// Run the seed-finding.
const spacepoint_formation_algorithm::output_type spacepoints =
m_spacepoint_formation(m_device_detector_view, measurements);
const track_params_estimation::output_type track_params =
m_track_parameter_estimation(spacepoints, m_seeding(spacepoints),
{0.f, 0.f, m_finder_config.bFieldInZ});

// Run the track finding.
const finding_algorithm::output_type track_candidates = m_finding(
m_device_detector_view, m_field, measurements, track_params);

// Get the final data back to the host.
bound_track_parameters_collection_types::host result(
&(m_host_mr.get()));
m_copy(track_params, result)->wait();
m_copy(track_candidates.headers, result)->wait();

// Return the host container.
return result;
Expand Down
Loading
Loading