Skip to content

Commit

Permalink
Fixes #591: Some work on the memory copy functions
Browse files Browse the repository at this point in the history
  • Loading branch information
eyalroz committed Feb 26, 2024
1 parent 8f55efb commit 94ae04a
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 4 deletions.
57 changes: 53 additions & 4 deletions src/cuda/api/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,24 @@ template <typename T, size_t N>
inline void copy(region_t destination, const T(&source)[N])
{
#ifndef NDEBUG
if (destination.size() < N) {
if (destination.size() < N * sizeof(T)) {
throw ::std::logic_error("Source size exceeds destination size");
}
#endif
return copy(destination.start(), source, sizeof(T) * N);
}

template <typename T, size_t N>
inline void copy(span<T> destination, const T(&source)[N])
{
#ifndef NDEBUG
if (destination.size() < N) {
throw ::std::logic_error("Source size exceeds destination size");
}
#endif
return copy(destination.data(), source, sizeof(T) * N);
}

/**
* @param destination A region of memory to which to copy the data in @p source,
* of size at least that of @p source.
Expand All @@ -471,6 +482,19 @@ inline void copy(T(&destination)[N], const_region_t source)
return copy(destination, source.start(), sizeof(T) * N);
}

template <typename T, size_t N>
inline void copy(T(&destination)[N], span<T const> source)
{
#ifndef NDEBUG
if (source.size() > N) {
throw ::std::invalid_argument(
"Attempt to copy a span of " + ::std::to_string(source.size()) +
" elements into an array of " + ::std::to_string(N) + " elements");
}
#endif
return copy(destination, source.start(), sizeof(T) * N);
}

template <typename T, size_t N>
inline void copy(void* destination, T (&source)[N])
{
Expand Down Expand Up @@ -669,6 +693,19 @@ void copy(const array_t<T, NumDimensions>& destination, const T *source)
copy(destination, context_of(source), source);
}

template<typename T, dimensionality_t NumDimensions>
void copy(const array_t<T, NumDimensions>& destination, span<T const> source)
{
#ifndef NDEBUG
if (destination.size() < source.size()) {
throw ::std::invalid_argument(
"Attempt to copy a span of " + ::std::to_string(source.size()) +
" elements into a CUDA array of " + ::std::to_string(destination.size()) + " elements");
}
#endif
copy(destination, source.data());
}

/**
* Synchronously copies data into a CUDA array from non-array memory.
*
Expand Down Expand Up @@ -710,6 +747,19 @@ void copy(T *destination, const array_t<T, NumDimensions>& source)
copy(context_of(destination), destination, source);
}

template <typename T, dimensionality_t NumDimensions>
void copy(span<T> destination, const array_t<T, NumDimensions>& source)
{
#ifndef NDEBUG
if (destination.size() < source.size()) {
throw ::std::invalid_argument(
"Attempt to copy a CUDA array of " + ::std::to_string(source.size()) +
" elements into a span of " + ::std::to_string(destination.size()) + " elements");
}
#endif
copy(destination.data(), source);
}

template <typename T, dimensionality_t NumDimensions>
void copy(const array_t<T, NumDimensions>& destination, const array_t<T, NumDimensions>& source)
{
Expand Down Expand Up @@ -742,7 +792,7 @@ void copy(const array_t<T, NumDimensions>& destination, const_region_t source)
if (destination.size_bytes() < source.size()) {
throw ::std::logic_error("Attempt to copy into an array from a source region larger than the array's size");
}
copy(destination, source.start());
copy(destination, static_cast<T const*>(source.start()));
}

/**
Expand Down Expand Up @@ -854,7 +904,6 @@ status_t multidim_copy(
return multidim_copy_in_current_context(::std::integral_constant<dimensionality_t, NumDimensions>{}, params, stream_handle);
}


// Assumes the array and the stream share the same context, and that the destination is
// accessible from that context (e.g. allocated within it, or being managed memory, etc.)
template <typename T, dimensionality_t NumDimensions>
Expand Down Expand Up @@ -1021,7 +1070,7 @@ void copy(array_t<T, NumDimensions>& destination, const_region_t source, const s
" bytes into an array of size " + ::std::to_string(required_size) + " bytes");
}
#endif
copy(destination, source.start(), stream);
copy(destination, static_cast<T const*>(source.start()), stream);
}

/**
Expand Down
26 changes: 26 additions & 0 deletions src/cuda/api/multi_wrapper_impls/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ inline void copy(array_t<T, NumDimensions>& destination, const T* source, const
detail_::copy<T, NumDimensions>(destination, source, stream.handle());
}

template <typename T, dimensionality_t NumDimensions>
inline void copy(array_t<T, NumDimensions>& destination, span<T const> source, const stream_t& stream)
{
#ifndef NDEBUG
if (source.size() != destination.size()) {
throw ::std::invalid_argument(
"Attempt to copy " + ::std::to_string(source.size()) +
" elements into an array of " + ::std::to_string(destination.size()) + " elements");
}
#endif
detail_::copy<T, NumDimensions>(destination, source.data(), stream.handle());
}

// Note: Assumes the destination, source and stream are all usable on the same content
template <typename T, dimensionality_t NumDimensions>
inline void copy(T* destination, const array_t<T, NumDimensions>& source, const stream_t& stream)
Expand All @@ -55,6 +68,19 @@ inline void copy(T* destination, const array_t<T, NumDimensions>& source, const
detail_::copy<T, NumDimensions>(destination, source, stream.handle());
}

template <typename T, dimensionality_t NumDimensions>
inline void copy(span<T> destination, const array_t<T, NumDimensions>& source, const stream_t& stream)
{
#ifndef NDEBUG
if (destination.size() != source.size()) {
throw ::std::invalid_argument(
"Attempt to copy " + ::std::to_string(source.size()) +
" elements into an array of " + ::std::to_string(destination.size()) + " elements");
}
#endif
copy(destination.data(), source, stream);
}

template <typename T>
inline void copy_single(T& destination, const T& source, const stream_t& stream)
{
Expand Down

0 comments on commit 94ae04a

Please sign in to comment.