Add CUDAPinnedPlace #9380

Merged

Changes from 5 commits
32 changes: 10 additions & 22 deletions paddle/fluid/framework/tensor.h
@@ -45,11 +45,10 @@ class Tensor {
friend struct EigenVector;

public:
Tensor() : offset_(0), is_pinned_(false) {}
Tensor() : offset_(0) {}

/*! Constructor with place should only be used in pybind. */
explicit Tensor(const platform::Place& place)
: offset_(0), is_pinned_(false) {
explicit Tensor(const platform::Place& place) : offset_(0) {
holder_->set_place(place);
}

@@ -70,12 +69,11 @@
* @note If not exist, then allocation.
*/
template <typename T>
inline T* mutable_data(platform::Place place, bool is_pinned = false);
inline T* mutable_data(platform::Place place);

inline void* mutable_data(platform::Place place, std::type_index type,
bool is_pinned = false);
inline void* mutable_data(platform::Place place, std::type_index type);

inline void* mutable_data(platform::Place place, bool is_pinned = false);
inline void* mutable_data(platform::Place place);

/**
* @brief Return a pointer to mutable memory block.
@@ -86,18 +84,14 @@ class Tensor {
* @note If not exist, then allocation.
*/
template <typename T>
inline T* mutable_data(DDim dims, platform::Place place,
bool is_pinned = false);
inline T* mutable_data(DDim dims, platform::Place place);

/*! Return the dimensions of the memory block. */
inline const DDim& dims() const;

/*! Return the numel of the memory block. */
inline int64_t numel() const;

/*! Return the numel of the memory block. */
inline bool isPinned() const;

/*! Resize the dimensions of the memory block. */
inline Tensor& Resize(const DDim& dims);

@@ -152,14 +146,12 @@ class Tensor {

template <typename Place>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place place, size_t size, std::type_index type,
bool is_pinned = false)
: ptr_(static_cast<uint8_t*>(memory::Alloc(place, size, is_pinned)),
memory::PODDeleter<uint8_t, Place>(place, is_pinned)),
PlaceholderImpl(Place place, size_t size, std::type_index type)
: ptr_(static_cast<uint8_t*>(memory::Alloc(place, size)),
memory::PODDeleter<uint8_t, Place>(place)),
place_(place),
size_(size),
type_(type),
is_pinned_(is_pinned) {
type_(type) {
PADDLE_ENFORCE_NOT_NULL(ptr_, "Insufficient %s memory to allocation.",
(is_cpu_place(place_) ? "CPU" : "GPU"));
}
@@ -182,9 +174,6 @@ class Tensor {

/* the current type of memory */
std::type_index type_;

/*! use pinned memory or not. */
bool is_pinned_;
};

/*! holds the memory block if allocated. */
@@ -219,7 +208,6 @@ class Tensor {
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t offset_;
bool is_pinned_;
};

inline void Tensor::switch_place(platform::Place new_place) {
23 changes: 9 additions & 14 deletions paddle/fluid/framework/tensor_impl.h
@@ -101,21 +101,19 @@ inline T* Tensor::data() {
}

template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place,
bool is_pinned) {
inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place, is_pinned);
return mutable_data<T>(place);
}

template <typename T>
inline T* Tensor::mutable_data(platform::Place place, bool is_pinned) {
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T), is_pinned));
return reinterpret_cast<T*>(mutable_data(place, typeid(T)));
}

inline void* Tensor::mutable_data(platform::Place place, std::type_index type,
bool is_pinned) {
inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
@@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type,
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
holder_.reset(new PlaceholderImpl<platform::CPUPlace>(
boost::get<platform::CPUPlace>(place), size, type, is_pinned));
boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
}
#else
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type, is_pinned));
boost::get<platform::CUDAPlace>(place), size, type));
}
#endif
offset_ = 0;
is_pinned_ = is_pinned;
}
return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}

inline void* Tensor::mutable_data(platform::Place place, bool is_pinned) {
inline void* Tensor::mutable_data(platform::Place place) {
PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing");
return mutable_data(place, holder_->type(), is_pinned);
return mutable_data(place, holder_->type());
}

inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
@@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; }

inline int64_t Tensor::numel() const { return product(dims_); }

inline bool Tensor::isPinned() const { return is_pinned_; }

inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
Tensor res;
res.ShareDataWith(src);
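As a usage note on the API change above: after this PR, callers request page-locked host memory by passing a pinned place to mutable_data instead of the removed is_pinned flag. A minimal sketch of the intended call pattern (assuming the CUDAPinnedPlace branch of mutable_data lands elsewhere in this PR; everything except Tensor, Resize, mutable_data, numel, and CUDAPinnedPlace is illustrative):

```cpp
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"

namespace f = paddle::framework;
namespace p = paddle::platform;

// Sketch: fill a tensor whose storage is pinned (page-locked) host memory.
void FillPinnedTensor() {
  f::Tensor t;
  t.Resize(f::make_ddim({8, 1024}));
  // Before this PR: t.mutable_data<float>(place, /*is_pinned=*/true);
  // After this PR: the place itself selects the pinned allocator.
  float* data = t.mutable_data<float>(p::CUDAPinnedPlace());
  for (int64_t i = 0; i < t.numel(); ++i) {
    data[i] = 0.0f;  // ordinary host writes; pinned memory is host memory
  }
}
```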
5 changes: 5 additions & 0 deletions paddle/fluid/framework/tensor_util.cc
@@ -148,6 +148,11 @@ struct AnyVisitor : public boost::static_visitor<bool> {
const platform::CPUPlace& cpu) const {
return *out.data<bool>();
}

bool GetResult(const framework::Tensor& out,
const platform::CUDAPinnedPlace& cpu) const {
return *out.data<bool>();
}
};

template <typename Predicate>
20 changes: 12 additions & 8 deletions paddle/fluid/memory/CMakeLists.txt
@@ -4,13 +4,17 @@ cc_library(memory SRCS memory.cc DEPS place enforce)
cc_library(memcpy SRCS memcpy.cc DEPS place)

cc_library(paddle_memory
DEPS
memory
memcpy
meta_data
meta_cache
memory_block
buddy_allocator
system_allocator)
DEPS
memory
memcpy
meta_data
meta_cache
memory_block
buddy_allocator
system_allocator)

cc_test(memory_test SRCS memory_test.cc DEPS place paddle_memory)

if (WITH_GPU)
nv_test(pinned_memory_test SRCS pinned_memory_test.cu DEPS place paddle_memory)
endif()
21 changes: 12 additions & 9 deletions paddle/fluid/memory/detail/system_allocator.cc
@@ -14,6 +14,7 @@ limitations under the License. */

#include "paddle/fluid/memory/detail/system_allocator.h"
#include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/gpu_info.h"

@@ -123,20 +124,22 @@ bool GPUAllocator::UseGpu() const { return true; }
// memory. It’s locked to a physical address.
void* CUDAPinnedAllocator::Alloc(size_t& index, size_t size) {
if (size <= 0) return nullptr;
void* p;
// NOTE: here, we use GpuMaxAllocSize() as the maximum memory size

// NOTE: here, we use CUDAPinnedMaxAllocSize as the maximum memory size
// of host pinned allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging.

size_t usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;
size_t usable =
    paddle::platform::CUDAPinnedMaxAllocSize() - cuda_pinnd_alloc_size_;

gongweibao (Contributor), Apr 3, 2018:
Is line 58 FLAGS_use_pinned_memory useful now?

Author:
Yes.

Reviewer (Contributor):
It seems the default maximum size of pinned memory is determined by system settings; ulimit -l shows the current locked-memory limit. I took a look at our current machines and the default value is very small (64 KB). To increase the setting you need to run ulimit -l [new size], so it is probably better to add a document describing this.

Author:
Thanks for your review! Searching around the internet, I found that:

  • ulimit -l does affect the amount of memory we can memlock(), but it does not affect cudaMallocHost(), because the CUDA pinning allocator doesn't use memlock.
  • The pinned allocator in CUDA is, under the hood, using mmap() with MAP_FIXED. The experiment is here.
  • So, theoretically, the maximum size of pinned memory can be the size of physical memory.

Reviewer (Contributor):
Thanks for the detailed information, quite useful!

if (size > usable) return nullptr;

void* p;
// PINNED memory is visible to all CUDA contexts.
cudaError_t result = cudaMallocHost(&p, size);

if (result == cudaSuccess) {
index = 1;
fallback_alloc_size_ += size;
index = 1; // PINNED memory
cuda_pinnd_alloc_size_ += size;
return p;
}

@@ -147,8 +150,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
cudaError_t err;
PADDLE_ASSERT(index == 1);

PADDLE_ASSERT(fallback_alloc_size_ >= size);
fallback_alloc_size_ -= size;
PADDLE_ASSERT(cuda_pinnd_alloc_size_ >= size);
cuda_pinnd_alloc_size_ -= size;
err = cudaFreeHost(p);

// Purposefully allow cudaErrorCudartUnloading, because
@@ -161,7 +164,7 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
}
}

bool CUDAPinnedAllocator::UseGpu() const { return true; }
bool CUDAPinnedAllocator::UseGpu() const { return false; }

#endif

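To make the review discussion above concrete, here is a small stand-alone experiment (not part of the PR) showing that cudaMallocHost is not bounded by RLIMIT_MEMLOCK, the limit behind ulimit -l, since the CUDA driver pins pages without going through mlock():

```cpp
// Stand-alone experiment, not part of this PR; build with nvcc.
#include <cuda_runtime.h>
#include <sys/resource.h>
#include <cstdio>

int main() {
  rlimit lim{};
  getrlimit(RLIMIT_MEMLOCK, &lim);  // the limit behind `ulimit -l` (getrlimit reports bytes)
  std::printf("RLIMIT_MEMLOCK soft limit: %llu bytes\n",
              static_cast<unsigned long long>(lim.rlim_cur));

  // Ask for 256 MiB of pinned host memory, far above the 64 KB default
  // memlock limit mentioned in the review thread; this normally succeeds.
  void* p = nullptr;
  cudaError_t err = cudaMallocHost(&p, 256ull << 20);
  std::printf("cudaMallocHost(256 MiB): %s\n", cudaGetErrorString(err));
  if (err == cudaSuccess) cudaFreeHost(p);
  return 0;
}
```

Because the memlock limit does not constrain cudaMallocHost, the allocator above caps pinned allocations in software with CUDAPinnedMaxAllocSize() instead.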
4 changes: 1 addition & 3 deletions paddle/fluid/memory/detail/system_allocator.h
@@ -59,9 +59,7 @@ class CUDAPinnedAllocator : public SystemAllocator {
virtual bool UseGpu() const;

private:
size_t gpu_alloc_size_ =
0; // TODO(zcd): how to define the upper limit of CUDAPinnedMemory?
size_t fallback_alloc_size_ = 0;
size_t cuda_pinnd_alloc_size_ = 0;

Reviewer (Contributor):
Comments at line 24 should be modified.

Author:
Done, thanks!

};
#endif

39 changes: 39 additions & 0 deletions paddle/fluid/memory/memcpy.cc
@@ -56,6 +56,45 @@ void Copy<platform::CUDAPlace, platform::CUDAPlace>(
}
}

template <>
void Copy<platform::CPUPlace, platform::CUDAPinnedPlace>(
platform::CPUPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CPUPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CPUPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPinnedPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num) {
std::memcpy(dst, src, num);
}

template <>
void Copy<platform::CUDAPinnedPlace, platform::CUDAPlace>(
platform::CUDAPinnedPlace dst_place, void* dst,
platform::CUDAPlace src_place, const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}

template <>
void Copy<platform::CUDAPlace, platform::CUDAPinnedPlace>(
platform::CUDAPlace dst_place, void* dst,
platform::CUDAPinnedPlace src_place, const void* src, size_t num,
cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}

#endif

} // namespace memory
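For illustration, a hedged sketch of how the new pinned-place specializations above could be used to stage an asynchronous device-to-host copy. It assumes memory::Alloc gains a CUDAPinnedPlace overload in this PR; the helper name is made up:

```cpp
#include <cuda_runtime.h>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/place.h"

namespace m = paddle::memory;
namespace p = paddle::platform;

// Sketch: copy `num` bytes from GPU 0 into freshly allocated pinned host
// memory, asynchronously on `stream`. The caller must synchronize the
// stream before reading the result and is responsible for freeing the buffer.
void* CopyToPinnedAsync(const void* gpu_src, size_t num, cudaStream_t stream) {
  p::CUDAPlace gpu(0);
  p::CUDAPinnedPlace pinned;
  void* dst = m::Alloc(pinned, num);  // assumed pinned-place Alloc overload
  // Resolves to the Copy<CUDAPinnedPlace, CUDAPlace> specialization added
  // above, i.e. GpuMemcpyAsync with cudaMemcpyDeviceToHost on `stream`.
  m::Copy(pinned, dst, gpu, gpu_src, num, stream);
  return dst;
}
```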