Skip to content

Commit

Permalink
Fixes #593: Copy parameters wrapper class work:
Browse files Browse the repository at this point in the history
* Added the missing implementation of `bytes_extent()`
* When setting an endpoint, now also setting the extents, if they have not yet been sent
* `set_source_untyped()` now correctly delegates to `set_endpoint_untyped()`
  • Loading branch information
eyalroz committed Feb 29, 2024
1 parent bebd4c2 commit 4989d48
Showing 1 changed file with 87 additions and 51 deletions.
138 changes: 87 additions & 51 deletions src/cuda/api/copy_parameters.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ struct copy_parameters_t : detail_::base_copy_params_t<NumDimensions> {

this_type& set_source_untyped(context::handle_t context_handle, void *ptr, dimensions_type dimensions)
{
return set_endpoint(endpoint_t::source, context_handle, ptr, dimensions);
return set_endpoint_untyped(endpoint_t::source, context_handle, ptr, dimensions);
}

template<typename T>
Expand Down Expand Up @@ -273,28 +273,14 @@ inline copy_parameters_t<2>& copy_parameters_t<2>::set_endpoint_untyped(
endpoint_t endpoint,
context::handle_t,
void * ptr,
array::dimensions_t<2> dimensions)
{
auto memory_type = memory::type_of(ptr);
if (memory_type == memory::type_t::array) {
throw ::std::invalid_argument("Attempt to use the non-array endpoint setter with array memory at " + cuda::detail_::ptr_as_hex(ptr));
}
if (memory_type == memory::type_t::unified_ or memory_type == type_t::device_)
{
(endpoint == endpoint_t::source ? srcDevice : dstDevice) = device::address(ptr);
}
else {
// Either memory::type_t::host or memory::type_t::non_cuda
if (endpoint == endpoint_t::source) { srcHost = ptr; }
else { dstHost = ptr; }
}
set_bytes_pitch(endpoint, dimensions.width);
(endpoint == endpoint_t::source ? srcMemoryType : dstMemoryType) = static_cast<CUmemorytype>
(memory_type == memory::type_t::non_cuda ? memory::type_t::host_ : memory_type);
// Can't set the endpoint context - the basic data structure doesn't support that!
// (endpoint == endpoint_t::source ? srcContext : dstContext) = context_handle;
return *this;
}
array::dimensions_t<2> dimensions);

template<>
inline copy_parameters_t<3>& copy_parameters_t<3>::set_endpoint_untyped(
endpoint_t endpoint,
context::handle_t,
void * ptr,
array::dimensions_t<3> dimensions);

template<>
template<typename T>
Expand Down Expand Up @@ -365,34 +351,6 @@ copy_parameters_t<3>& copy_parameters_t<3>::set_endpoint(endpoint_t endpoint, co
return *this;
}

template<>
inline copy_parameters_t<3>& copy_parameters_t<3>::set_endpoint_untyped(
endpoint_t endpoint,
context::handle_t context_handle,
void * ptr,
array::dimensions_t<3> dimensions)
{
auto memory_type = memory::type_of(ptr);
if (memory_type == memory::type_t::array) {
throw ::std::invalid_argument("Attempt to use the non-array endpoint setter with array memory at " + cuda::detail_::ptr_as_hex(ptr));
}
if (memory_type == memory::type_t::unified_ or memory_type == type_t::device_)
{
(endpoint == endpoint_t::source ? srcDevice : dstDevice) = device::address(ptr);
}
else {
// Either memory::type_t::host or memory::type_t::non_cuda
if (endpoint == endpoint_t::source) { srcHost = ptr; }
else { dstHost = ptr; }
}
set_bytes_pitch(endpoint, dimensions.width);
(endpoint == endpoint_t::source ? srcHeight : dstHeight) = dimensions.height;
(endpoint == endpoint_t::source ? srcMemoryType : dstMemoryType) = static_cast<CUmemorytype>
(memory_type == memory::type_t::non_cuda ? memory::type_t::host_ : memory_type);
(endpoint == endpoint_t::source ? srcContext : dstContext) = context_handle;
return *this;
}

// 2D copy parameters only have an intra-context variant; should we silently assume the context
// is the same for both ends?
template<>
Expand Down Expand Up @@ -467,6 +425,84 @@ inline copy_parameters_t<3>& copy_parameters_t<3>::set_bytes_extent(dimensions_t
return *this;
}

template<>
inline copy_parameters_t<2>::dimensions_type copy_parameters_t<2>::bytes_extent() const noexcept
{
return copy_parameters_t<2>::dimensions_type { WidthInBytes, Height };
}

template<>
inline copy_parameters_t<3>::dimensions_type copy_parameters_t<3>::bytes_extent() const noexcept
{
return copy_parameters_t<3>::dimensions_type { WidthInBytes, Height, Depth };
}

template<>
inline copy_parameters_t<2>& copy_parameters_t<2>::set_endpoint_untyped(
endpoint_t endpoint,
context::handle_t,
void * ptr,
array::dimensions_t<2> dimensions)
{
auto memory_type = memory::type_of(ptr);
if (memory_type == memory::type_t::array) {
throw ::std::invalid_argument("Attempt to use the non-array endpoint setter with array memory at " + cuda::detail_::ptr_as_hex(ptr));
}
if (memory_type == memory::type_t::unified_ or memory_type == type_t::device_)
{
(endpoint == endpoint_t::source ? srcDevice : dstDevice) = device::address(ptr);
}
else {
// Either memory::type_t::host or memory::type_t::non_cuda
if (endpoint == endpoint_t::source) { srcHost = ptr; }
else { dstHost = ptr; }
}
set_bytes_pitch(endpoint, dimensions.width);
(endpoint == endpoint_t::source ? srcMemoryType : dstMemoryType) = static_cast<CUmemorytype>
(memory_type == memory::type_t::non_cuda ? memory::type_t::host_ : memory_type);
// Can't set the endpoint context - the basic data structure doesn't support that!
// (endpoint == endpoint_t::source ? srcContext : dstContext) = context_handle;

if (bytes_extent().area() == 0) {
set_bytes_extent(dimensions);
}
return *this;
}

template<>
inline copy_parameters_t<3>& copy_parameters_t<3>::set_endpoint_untyped(
endpoint_t endpoint,
context::handle_t context_handle,
void * ptr,
array::dimensions_t<3> dimensions)
{
auto memory_type = memory::type_of(ptr);
if (memory_type == memory::type_t::array) {
throw ::std::invalid_argument("Attempt to use the non-array endpoint setter with array memory at " + cuda::detail_::ptr_as_hex(ptr));
}
if (memory_type == memory::type_t::unified_ or memory_type == type_t::device_)
{
(endpoint == endpoint_t::source ? srcDevice : dstDevice) = device::address(ptr);
}
else {
// Either memory::type_t::host or memory::type_t::non_cuda
if (endpoint == endpoint_t::source) { srcHost = ptr; }
else { dstHost = ptr; }
}
set_bytes_pitch(endpoint, dimensions.width);
(endpoint == endpoint_t::source ? srcHeight : dstHeight) = dimensions.height;
(endpoint == endpoint_t::source ? srcMemoryType : dstMemoryType) = static_cast<CUmemorytype>
(memory_type == memory::type_t::non_cuda ? memory::type_t::host_ : memory_type);
(endpoint == endpoint_t::source ? srcContext : dstContext) = context_handle;

if (bytes_extent().volume() == 0) {
set_bytes_extent(dimensions);
}

return *this;
}


template<>
template<typename T>
copy_parameters_t<3>& copy_parameters_t<3>::set_extent(dimensions_type extent_in_elements) noexcept
Expand Down

0 comments on commit 4989d48

Please sign in to comment.