Skip to content

Commit

Permalink
Merge pull request #2970 from gangliao/memcpy
Browse files Browse the repository at this point in the history
Add CPU/GPU Memcpy in memory folder
  • Loading branch information
wangkuiyi authored Jul 21, 2017
2 parents 6129ab4 + 6cae35b commit e1140f2
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 25 deletions.
45 changes: 44 additions & 1 deletion paddle/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ limitations under the License. */
#include "paddle/memory/memory.h"
#include "paddle/memory/detail/buddy_allocator.h"
#include "paddle/memory/detail/system_allocator.h"
#include "paddle/platform/assert.h"

#include <cstring> // for memcpy

namespace paddle {
namespace memory {
Expand Down Expand Up @@ -45,6 +46,13 @@ size_t Used<platform::CPUPlace>(platform::CPUPlace place) {
return GetCPUBuddyAllocator()->Used();
}

template <>
void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
platform::CPUPlace,
const void* src, size_t num) {
std::memcpy(dst, src, num);
}

#ifndef PADDLE_ONLY_CPU

detail::BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
Expand Down Expand Up @@ -77,6 +85,41 @@ size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}

template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

template <>
void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
if (dst_place == src_place) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
} else {
platform::GpuMemcpyPeer(dst, dst_place.device, src, src_place.device, num,
stream);
}
}

#endif // PADDLE_ONLY_CPU

} // namespace memory
Expand Down
16 changes: 13 additions & 3 deletions paddle/memory/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,30 @@ limitations under the License. */

#pragma once

#include "paddle/platform/gpu_info.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace memory {

template <class Place>
template <typename Place>
void* Alloc(Place, size_t);

template <class Place>
template <typename Place>
void Free(Place, void*);

template <class Place>
template <typename Place>
size_t Used(Place);

template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);

#ifndef PADDLE_ONLY_CPU
template <typename DstPlace, typename SrcPlace>
void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num,
cudaStream_t stream);
#endif // PADDLE_ONLY_CPU

template <typename T, /* must be POD types */
typename Place /* platform::GPUPlace or platform::CPUPlace */,
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
Expand Down
45 changes: 26 additions & 19 deletions paddle/platform/enforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,26 @@ namespace platform {
// For more details, please check https://stackoverflow.com/a/43870188/724872.
#define UNLIKELY(condition) __builtin_expect(static_cast<bool>(condition), 0)

template <typename T>
inline void throw_on_error(T e) {
throw_on_error(e, "");
}

template <typename... Args>
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
int stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
throw std::runtime_error(
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
}
}

#ifndef PADDLE_ONLY_CPU

template <typename... Args>
inline void throw_on_error(cudaError_t e, const Args&... args) {
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudaError_t e, const Args&... args) {
if (UNLIKELY(e)) {
// clang-format off
throw thrust::system_error(
Expand All @@ -58,7 +74,8 @@ inline void throw_on_error(cudaError_t e, const Args&... args) {
}

template <typename... Args>
inline void throw_on_error(curandStatus_t stat, const Args&... args) {
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
curandStatus_t stat, const Args&... args) {
if (stat != CURAND_STATUS_SUCCESS) {
// clang-format off
throw thrust::system_error(
Expand All @@ -70,7 +87,8 @@ inline void throw_on_error(curandStatus_t stat, const Args&... args) {
}

template <typename... Args>
inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cudnnStatus_t stat, const Args&... args) {
if (stat == CUDNN_STATUS_SUCCESS) {
return;
} else {
Expand All @@ -84,7 +102,8 @@ inline void throw_on_error(cudnnStatus_t stat, const Args&... args) {
}

template <typename... Args>
inline void throw_on_error(cublasStatus_t stat, const Args&... args) {
inline typename std::enable_if<sizeof...(Args) != 0, void>::type throw_on_error(
cublasStatus_t stat, const Args&... args) {
std::string err;
if (stat == CUBLAS_STATUS_SUCCESS) {
return;
Expand Down Expand Up @@ -113,28 +132,16 @@ inline void throw_on_error(cublasStatus_t stat, const Args&... args) {

#endif // PADDLE_ONLY_CPU

template <typename... Args>
inline void throw_on_error(int stat, const Args&... args) {
if (UNLIKELY(!(stat))) {
throw std::runtime_error(
string::Sprintf(args...) +
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__));
}
}

#define PADDLE_THROW(...) \
do { \
throw std::runtime_error( \
string::Sprintf(__VA_ARGS__) + \
string::Sprintf(" at [%s:%s];", __FILE__, __LINE__)); \
} while (0)

/**
* @brief Enforce a condition, otherwise throw an EnforceNotMet
*/
#define PADDLE_ENFORCE(condition, ...) \
do { \
::paddle::platform::throw_on_error(condition, __VA_ARGS__); \
#define PADDLE_ENFORCE(...) \
do { \
::paddle::platform::throw_on_error(__VA_ARGS__); \
} while (0)

} // namespace platform
Expand Down
25 changes: 24 additions & 1 deletion paddle/platform/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ void SetDeviceId(int id) {
"cudaSetDevice failed in paddle::platform::SetDeviceId");
}

void GpuMemoryUsage(size_t& available, size_t& total) {
void GpuMemoryUsage(size_t &available, size_t &total) {
PADDLE_ENFORCE(cudaMemGetInfo(&available, &total),
"cudaMemGetInfo failed in paddle::platform::GetMemoryUsage");
}
Expand Down Expand Up @@ -82,5 +82,28 @@ size_t GpuMaxChunkSize() {
return usable;
}

void GpuMemcpyAsync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind, cudaStream_t stream) {
PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
}

void GpuMemcpySync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind) {
PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
"cudaMemcpy failed in paddle::platform::GpuMemcpySync");
// note: cudaMemcpy may actually be asynchronous with respect to the caller,
// block on stream 0 to make sure the copy has completed
PADDLE_ENFORCE(
cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in paddle::platform::GpuMemcpySync");
}

void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
size_t count, cudaStream_t stream) {
PADDLE_ENFORCE(
cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
"cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeer");
}
} // namespace platform
} // namespace paddle
15 changes: 14 additions & 1 deletion paddle/platform/gpu_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License. */

#ifndef PADDLE_ONLY_CPU

#include <cuda_runtime.h>
#include <stddef.h>

namespace paddle {
Expand All @@ -31,7 +32,7 @@ int GetCurrentDeviceId();
void SetDeviceId(int device_id);

//!Get the memory usage of current GPU device.
void GpuMemoryUsage(size_t& available, size_t& total);
void GpuMemoryUsage(size_t &available, size_t &total);

//! Get the maximum allocation size of current GPU device.
size_t GpuMaxAllocSize();
Expand All @@ -42,6 +43,18 @@ size_t GpuMinChunkSize();
//! Get the maximum chunk size for GPU buddy allocator.
size_t GpuMaxChunkSize();

//! Copy memory from address src to dst asynchronously.
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind, cudaStream_t stream);

//! Copy memory from address src to dst synchronously.
void GpuMemcpySync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind);

//! Copy memory from one device to another device.
void GpuMemcpyPeer(void *dst, int dst_device, const void *src, int src_device,
size_t count, cudaStream_t stream);

} // namespace platform
} // namespace paddle

Expand Down

0 comments on commit e1140f2

Please sign in to comment.