Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

traccc::sycl::full_chain_algorithm Fixes, main branch (2024.11.12.) #769

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/run/sycl/full_chain_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,13 @@ class full_chain_algorithm

private:
/// Private data object
details::full_chain_algorithm_data* m_data;
std::unique_ptr<details::full_chain_algorithm_data> m_data;
/// Host memory resource
vecmem::memory_resource& m_host_mr;
std::reference_wrapper<vecmem::memory_resource> m_host_mr;
/// Device memory resource
std::unique_ptr<vecmem::sycl::device_memory_resource> m_device_mr;
vecmem::sycl::device_memory_resource m_device_mr;
/// Device caching memory resource
std::unique_ptr<vecmem::binary_page_memory_resource> m_cached_device_mr;
mutable vecmem::binary_page_memory_resource m_cached_device_mr;
/// Memory copy object
mutable vecmem::sycl::async_copy m_copy;

Expand Down
126 changes: 71 additions & 55 deletions examples/run/sycl/full_chain_algorithm.sycl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "full_chain_algorithm.hpp"

// SYCL include(s).
#include <CL/sycl.hpp>
#include <sycl/sycl.hpp>

// System include(s).
#include <exception>
Expand All @@ -35,7 +35,14 @@ namespace traccc::sycl {
namespace details {

struct full_chain_algorithm_data {

/// Constructor
explicit full_chain_algorithm_data(const ::sycl::async_handler& handler)
: m_queue{handler} {}

/// The native SYCL queue object
::sycl::queue m_queue;
/// Wrapper around the SYCL queue object
queue_wrapper m_queue_wrapper{&m_queue};
};

Expand All @@ -51,29 +58,33 @@ full_chain_algorithm::full_chain_algorithm(
const fitting_algorithm::config_type&,
const silicon_detector_description::host& det_descr,
host_detector_type* detector)
: m_data(new details::full_chain_algorithm_data{{::handle_async_error}}),
: m_data(std::make_unique<details::full_chain_algorithm_data>(
::handle_async_error)),
m_host_mr(host_mr),
m_device_mr(std::make_unique<vecmem::sycl::device_memory_resource>(
&(m_data->m_queue))),
m_cached_device_mr(
std::make_unique<vecmem::binary_page_memory_resource>(*m_device_mr)),
m_copy(&(m_data->m_queue)),
m_device_mr{&(m_data->m_queue)},
m_cached_device_mr{m_device_mr},
m_copy{&(m_data->m_queue)},
m_det_descr(det_descr),
m_device_det_descr(
m_device_det_descr{
static_cast<silicon_detector_description::buffer::size_type>(
m_det_descr.get().size()),
*m_device_mr),
m_device_mr},
m_detector(detector),
m_clusterization(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
m_data->m_queue_wrapper, clustering_config),
m_spacepoint_formation(memory_resource{*m_cached_device_mr, &m_host_mr},
m_copy, m_data->m_queue_wrapper),
m_seeding(finder_config, grid_config, filter_config,
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
&(m_data->m_queue)),
m_track_parameter_estimation(
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
&(m_data->m_queue)),
m_device_detector{},
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy, m_data->m_queue_wrapper, clustering_config},
m_spacepoint_formation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_seeding{finder_config,
grid_config,
filter_config,
memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy,
m_data->m_queue_wrapper},
m_track_parameter_estimation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_clustering_config(clustering_config),
m_finder_config(finder_config),
m_grid_config(grid_config),
Expand All @@ -85,71 +96,74 @@ full_chain_algorithm::full_chain_algorithm(
<< m_data->m_queue.get_device().get_info<::sycl::info::device::name>()
<< std::endl;

// Creating SYCL queue object
::sycl::queue q(handle_async_error);
traccc::sycl::queue_wrapper queue{&q};

// Copy the detector (description) to the device.
m_copy(vecmem::get_data(m_det_descr.get()), m_device_det_descr)->wait();
if (m_detector != nullptr) {
m_device_detector = detray::get_buffer(detray::get_data(*m_detector),
*m_device_mr, m_copy);
q.wait_and_throw();
m_device_mr, m_copy);
m_device_detector_view = detray::get_data(m_device_detector);
}
}

full_chain_algorithm::full_chain_algorithm(const full_chain_algorithm& parent)
: m_data(new details::full_chain_algorithm_data{{::handle_async_error}}),
: m_data(std::make_unique<details::full_chain_algorithm_data>(
::handle_async_error)),
m_host_mr(parent.m_host_mr),
m_device_mr(std::make_unique<vecmem::sycl::device_memory_resource>(
&(m_data->m_queue))),
m_cached_device_mr(
std::make_unique<vecmem::binary_page_memory_resource>(*m_device_mr)),
m_copy(&(m_data->m_queue)),
m_device_mr{&(m_data->m_queue)},
m_cached_device_mr{m_device_mr},
m_copy{&(m_data->m_queue)},
m_det_descr(parent.m_det_descr),
m_device_det_descr(
m_device_det_descr{
static_cast<silicon_detector_description::buffer::size_type>(
m_det_descr.get().size()),
*m_device_mr),
m_clusterization(memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
m_data->m_queue_wrapper, parent.m_clustering_config),
m_spacepoint_formation(memory_resource{*m_cached_device_mr, &m_host_mr},
m_copy, m_data->m_queue_wrapper),
m_seeding(parent.m_finder_config, parent.m_grid_config,
m_device_mr},
m_detector(parent.m_detector),
m_device_detector{},
m_clusterization{memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy, m_data->m_queue_wrapper,
parent.m_clustering_config},
m_spacepoint_formation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_seeding{parent.m_finder_config,
parent.m_grid_config,
parent.m_filter_config,
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
&(m_data->m_queue)),
m_track_parameter_estimation(
memory_resource{*m_cached_device_mr, &m_host_mr}, m_copy,
&(m_data->m_queue)),
memory_resource{m_cached_device_mr, &(m_host_mr.get())},
m_copy,
m_data->m_queue_wrapper},
m_track_parameter_estimation{
memory_resource{m_cached_device_mr, &(m_host_mr.get())}, m_copy,
m_data->m_queue_wrapper},
m_clustering_config(parent.m_clustering_config),
m_finder_config(parent.m_finder_config),
m_grid_config(parent.m_grid_config),
m_filter_config(parent.m_filter_config) {

// Copy the detector (description) to the device.
m_copy(vecmem::get_data(m_det_descr.get()), m_device_det_descr)->wait();
if (m_detector != nullptr) {
m_device_detector = detray::get_buffer(detray::get_data(*m_detector),
m_device_mr, m_copy);
m_device_detector_view = detray::get_data(m_device_detector);
}
}

full_chain_algorithm::~full_chain_algorithm() {
// Need to ensure that objects would be deleted in the correct order.
delete m_data;
}
full_chain_algorithm::~full_chain_algorithm() = default;

full_chain_algorithm::output_type full_chain_algorithm::operator()(
const edm::silicon_cell_collection::host& cells) const {

// Create device copy of input collections
edm::silicon_cell_collection::buffer cells_buffer(
static_cast<unsigned int>(cells.size()), *m_cached_device_mr);
edm::silicon_cell_collection::buffer cells_buffer{
static_cast<unsigned int>(cells.size()), m_cached_device_mr};
m_copy(vecmem::get_data(cells), cells_buffer)->wait();

// Execute the algorithms.
const clusterization_algorithm::output_type measurements =
m_clusterization(cells_buffer, m_device_det_descr);

// If we have a Detray detector, run the seeding, track finding and fitting.
// If we have a Detray detector, run the seeding, track
// finding and fitting.
if (m_detector != nullptr) {

const spacepoint_formation_algorithm::output_type spacepoints =
Expand All @@ -159,18 +173,20 @@ full_chain_algorithm::output_type full_chain_algorithm::operator()(
{0.f, 0.f, m_finder_config.bFieldInZ});

// Get the final data back to the host.
bound_track_parameters_collection_types::host result(&m_host_mr);
(m_copy)(track_params, result)->wait();
bound_track_parameters_collection_types::host result(
&(m_host_mr.get()));
m_copy(track_params, result)->wait();

// Return the host container.
return result;
}
// If not, copy the measurements back to the host, and return a dummy
// object.
// If not, copy the measurements back to the host, and return
// a dummy object.
else {

// Copy the measurements back to the host.
measurement_collection_types::host measurements_host(&m_host_mr);
measurement_collection_types::host measurements_host(
&(m_host_mr.get()));
m_copy(measurements, measurements_host)->wait();

// Return an empty object.
Expand Down
Loading