Simplize system_allocator and fix GPU_INFO #6653

Merged · 1 commit · Dec 15, 2017
50 changes: 13 additions & 37 deletions paddle/memory/detail/system_allocator.cc
@@ -19,6 +19,7 @@ limitations under the License. */

#include <stdlib.h> // for malloc and free
#include <sys/mman.h> // for mlock and munlock
#include <algorithm> // for std::max

#include "gflags/gflags.h"

@@ -28,7 +29,7 @@ limitations under the License. */
// of memory available to the system for paging. So, by default, we
// should set false to use_pinned_memory.
DEFINE_bool(use_pinned_memory, true, "If set, allocate cpu pinned memory.");

DECLARE_double(fraction_of_gpu_memory_to_use);
namespace paddle {
namespace memory {
namespace detail {
@@ -77,45 +78,20 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
  // CUDA documentation doesn't explain if cudaMalloc returns nullptr
  // if size is 0. We just make sure it does.
  if (size <= 0) return nullptr;

  size_t available = 0;
  size_t capacity = 0;
  paddle::platform::GpuMemoryUsage(available, capacity);

  // Reserve memory for page tables, etc.
  size_t reserving = 0.05 * capacity + paddle::platform::GpuMinChunkSize();
  size_t usable = available > reserving ? available - reserving : 0;

  // If remaining size no less than expected size, using general
  // cudaMalloc to allocate GPU memory.
  void* p = 0;
  if (size <= usable) {
    cudaError_t result = cudaMalloc(&p, size);
    if (result == cudaSuccess) {
      index = 0;
      gpu_alloc_size_ += size;
      return p;
    }
  }

  // If remaining size less than expected size or cudaMalloc failed,
  // cudaMallocHost will be considered as a fallback allocator.
  //
  // NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
  // of host fallback allocation. Allocates too much would reduce
  // the amount of memory available to the underlying system for paging.
  usable = paddle::platform::GpuMaxAllocSize() - fallback_alloc_size_;

  if (size > usable) return nullptr;

  cudaError_t result = cudaMallocHost(&p, size);
Contributor:
Why remove cudaMallocHost?

Most of the memory required for training deep recurrent networks is used to store activations through each layer for use by back propagation, not to store the parameters of the network. For example, storing the weights for a 70M parameter network with 9 layers requires approximately 280 MB of memory, but storing the activations for a batch of 64 seven-second utterances requires 1.5 GB of memory. TitanX GPUs include 12 GB of GDDR5 RAM, and sometimes very deep networks can exceed the GPU memory capacity when processing long utterances. This can happen unpredictably, and it is desirable to avoid a catastrophic failure when it occurs.

The combination of fast memory allocation with a fallback mechanism that allows us to slightly overflow available GPU memory in exceptional cases makes the system significantly simpler, more robust, and more efficient.

This memory can be accessed directly by the GPU by forwarding individual memory transactions over PCIe at reduced bandwidth, and it allows a model to continue to make progress even after encountering an outlier.
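As an illustration of that mechanism only (not code from this PR), here is a minimal sketch using CUDA's mapped pinned-memory API (cudaHostAlloc with cudaHostAllocMapped plus cudaHostGetDevicePointer); the helper name AllocWithHostFallback is hypothetical:

#include <cuda_runtime.h>
#include <cstddef>

// Try device memory first; if that fails, fall back to pinned host memory
// that the GPU can address directly over PCIe (at reduced bandwidth).
void* AllocWithHostFallback(size_t size, bool* is_fallback) {
  void* p = nullptr;
  *is_fallback = false;
  if (cudaMalloc(&p, size) == cudaSuccess) {
    return p;  // normal device allocation
  }
  // Device memory exhausted: allocate mapped, pinned host memory instead.
  if (cudaHostAlloc(&p, size, cudaHostAllocMapped) == cudaSuccess) {
    void* device_view = nullptr;
    // GPU-visible alias of the host buffer; kernels access it over PCIe.
    if (cudaHostGetDevicePointer(&device_view, p, 0) == cudaSuccess) {
      *is_fallback = true;
      return device_view;
    }
    cudaFreeHost(p);
  }
  return nullptr;  // both allocations failed
}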

@reyoung @chengduoZH @wangkuiyi

Contributor:

Deep Speech 2 section 4.3, https://arxiv.org/pdf/1512.02595.pdf

Collaborator (Author):

From the Zen of Python:

Explicit is better than implicit.
Simple is better than complex.

In the previous implementation, cudaMallocHost was invoked IMPLICITLY when the GPU ran out of memory. Performance is very poor when kernels run against memory allocated with cudaMallocHost, and because the fallback was implicit, it was hard to debug.

It may be better to fail fast when out of memory. It is explicit and simple.

If this feature is needed, we can also implement it as a separate decorator rather than combining the two logics, for example:

class CUDAFallbackAllocator {
 public:
  void* alloc(size_t size) {
    void* ptr = allocator_->alloc(size);
    if (ptr != nullptr) {
      return ptr;
    }
    // Explicit fallback: pinned host memory when device memory is exhausted.
    void* host_ptr = nullptr;
    if (cudaMallocHost(&host_ptr, size) == cudaSuccess) {
      return host_ptr;
    }
    return nullptr;
  }

 private:
  CUDASystemAllocator* allocator_;
};
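A wrapper like this would keep the default GPU allocator fail-fast while making the host-memory fallback an explicit, opt-in choice for the callers that actually want it.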

Contributor:
I see, thanks.

  void* p;
  cudaError_t result = cudaMalloc(&p, size);
  if (result == cudaSuccess) {
    index = 1;
    fallback_alloc_size_ += size;
    index = 0;
    gpu_alloc_size_ += size;
    return p;
  } else {
    LOG(WARNING)
        << "Cannot malloc " << size / 1024.0 / 1024.0
        << " MB GPU memory. Please shrink FLAGS_fraction_of_gpu_memory_to_use "
           "environment variable to a lower value. Current value is "
        << FLAGS_fraction_of_gpu_memory_to_use;
    return nullptr;
  }

  return nullptr;
}

void GPUAllocator::Free(void* p, size_t size, size_t index) {
19 changes: 10 additions & 9 deletions paddle/platform/gpu_info.cc
@@ -73,19 +73,20 @@ size_t GpuMaxChunkSize() {
  size_t available = 0;

  GpuMemoryUsage(available, total);

  // Reserving the rest memory for page tables, etc.
  size_t reserving = 0.05 * total;

  VLOG(10) << "GPU Usage " << available / 1024 / 1024 << "M/"
           << total / 1024 / 1024 << "M";
  size_t reserving = static_cast<size_t>(0.05 * total);
  // If available less than minimum chunk size, no usable memory exists.
  available =
      std::max(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
               reserving) -
      reserving;
      std::min(std::max(available, GpuMinChunkSize()) - GpuMinChunkSize(),
               total - reserving);

  // Reserving the rest memory for page tables, etc.

  size_t allocating = FLAGS_fraction_of_gpu_memory_to_use * total;
  size_t allocating = static_cast<size_t>(FLAGS_fraction_of_gpu_memory_to_use *
                                          (total - reserving));

  PADDLE_ENFORCE_LT(allocating, available);
  PADDLE_ENFORCE_LE(allocating, available);

  return allocating;
}
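For concreteness, a small standalone sketch (not part of the patch) of the updated GpuMaxChunkSize() arithmetic, using made-up numbers in place of the real card capacity, GpuMinChunkSize(), and the gflag:

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  const size_t GB = 1ULL << 30;
  size_t total = 12 * GB;          // hypothetical card capacity
  size_t available = 11 * GB;      // hypothetical free memory
  size_t min_chunk_size = 1 << 8;  // stand-in for GpuMinChunkSize()
  double fraction = 0.92;          // stand-in for the gflag

  // Keep roughly 5% of the card for page tables, fragmentation, etc.
  size_t reserving = static_cast<size_t>(0.05 * total);

  // Usable memory: never below zero, never above total - reserving.
  available = std::min(std::max(available, min_chunk_size) - min_chunk_size,
                       total - reserving);

  // The pool pre-allocates a fraction of the non-reserved capacity.
  size_t allocating = static_cast<size_t>(fraction * (total - reserving));

  // Mirrors PADDLE_ENFORCE_LE(allocating, available): with these numbers,
  // allocating (~10.5 GB) does not exceed available (~11 GB).
  std::printf("reserving=%zu available=%zu allocating=%zu\n", reserving,
              available, allocating);
  return 0;
}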