The CUDA Async Allocator (#65092)
* Async Pool and Memory Throttling

* fix rocm build

* fix flag

* fix rocm build

* fix flag
eee4017 authored Jul 15, 2024
1 parent 101bf6e commit 8b808f1
Showing 11 changed files with 472 additions and 180 deletions.
22 changes: 22 additions & 0 deletions paddle/common/flags.cc
@@ -1130,6 +1130,28 @@ PHI_DEFINE_EXPORTED_bool(use_cuda_malloc_async_allocator,
false,
"Enable CUDAMallocAsyncAllocator");

/*
* CUDAMallocAsyncAllocator related FLAG
* Name: FLAGS_cuda_malloc_async_pool_memory_throttle_ratio
* Since Version: 3.0
* Value Range: double, [0.0, 1.0], default=0.8
 * Note: memory_throttle_ratio provides a threshold that determines when to
* initiate synchronization operations to deallocate memory. This mechanism
* helps in ensuring that the system does not exceed its memory capacity while
* also attempting to minimize performance degradation caused by frequent memory
* synchronization.
*
* Please see Note [cuda_malloc_async_pool_memory_throttle_ratio]
*/
PHI_DEFINE_EXPORTED_double(
cuda_malloc_async_pool_memory_throttle_ratio,
0.8,
"memory_throttle_ratio provides a threshold that determines when to "
"initiate synchronization operations to deallocate memory. "
"This mechanism helps in ensuring that the system does not exceed its "
"memory capacity while also attempting to minimize performance degradation "
"caused by frequent memory synchronization.");

/*
* CUDA Graph / Allocator related FLAG
* Name: FLAGS_auto_free_cudagraph_allocations_on_launch
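The note on FLAGS_cuda_malloc_async_pool_memory_throttle_ratio above describes the throttling idea only in prose. Below is a minimal, self-contained C++ sketch of how such a check might look — it is not the PR's actual implementation, and the names AsyncPoolSketch, pending_release_bytes, and FreeAsync are hypothetical placeholders used purely for illustration.

#include <cuda_runtime.h>
#include <cstddef>

// Sketch of the throttling idea behind the flag: memory freed with
// cudaFreeAsync becomes reusable only after the freeing stream has progressed,
// so the pool tracks how many bytes are still "pending release" and forces a
// synchronization once that amount crosses ratio * total device memory.
struct AsyncPoolSketch {
  double throttle_ratio;             // e.g. FLAGS_cuda_malloc_async_pool_memory_throttle_ratio
  size_t pending_release_bytes = 0;  // freed asynchronously, not yet reclaimable
  cudaStream_t stream = nullptr;     // stream the pool allocates and frees on

  void FreeAsync(void* ptr, size_t size, size_t total_device_bytes) {
    cudaFreeAsync(ptr, stream);      // hands the block back to the pool asynchronously
    pending_release_bytes += size;
    if (pending_release_bytes >
        static_cast<size_t>(throttle_ratio * total_device_bytes)) {
      // Pay one synchronization so queued frees complete and their memory can
      // be reused, rather than letting new allocations push the device to OOM.
      cudaStreamSynchronize(stream);
      pending_release_bytes = 0;
    }
  }
};

The sketch only illustrates the trade-off the flag controls: a lower ratio synchronizes earlier (safer against OOM, more overhead), a higher ratio synchronizes later (less overhead, closer to the memory limit).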
82 changes: 55 additions & 27 deletions paddle/fluid/memory/allocation/allocator_facade.cc
@@ -200,8 +200,10 @@ class AllocatorFacadePrivate {
:
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
default_stream_safe_cuda_allocators_(),
default_cuda_malloc_async_allocators_(),
cuda_allocators_(),
#endif
#ifdef PADDLE_WITH_CUDA
default_cuda_malloc_async_allocators_(),
#endif
allocators_() {
strategy_ = GetAllocatorStrategy();
@@ -433,7 +435,7 @@ class AllocatorFacadePrivate {
/* unique_lock_guard */ {
std::unique_lock<std::shared_timed_mutex> lock_guard(
cuda_allocator_mutex_);
InitStreamSafeCUDAAllocator(place, stream);
InitCUDAAllocator(place, stream);
return cuda_allocators_[place][stream];
}
}
@@ -443,9 +445,11 @@ class AllocatorFacadePrivate {
if (auto iter = default_stream_safe_cuda_allocators_.find(place);
iter != default_stream_safe_cuda_allocators_.end())
return iter->second;
#ifdef PADDLE_WITH_CUDA
if (auto iter = default_cuda_malloc_async_allocators_.find(place);
iter != default_cuda_malloc_async_allocators_.end())
return iter->second;
#endif
PADDLE_THROW(platform::errors::NotFound(
"No StreamSafeCUDAAllocator found for the place, %s", place));
}
@@ -454,10 +458,12 @@ class AllocatorFacadePrivate {
if (auto allocator = std::dynamic_pointer_cast<StreamSafeCUDAAllocator>(
GetDefaultStreamSafeCUDAAllocator(place))) {
return allocator->GetDefaultStream();
#ifdef PADDLE_WITH_CUDA
} else if (auto allocator =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocator>(
GetDefaultStreamSafeCUDAAllocator(place))) {
return allocator->GetDefaultStream();
#endif
} else {
PADDLE_THROW(platform::errors::NotFound(
"No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for "
@@ -484,6 +490,7 @@ class AllocatorFacadePrivate {
VLOG(8) << "Set default stream to " << stream
<< " for StreamSafeCUDAAllocator(" << allocator.get() << ") in "
<< place;
#ifdef PADDLE_WITH_CUDA
} else if (auto allocator =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocator>(
GetDefaultStreamSafeCUDAAllocator(place))) {
@@ -501,6 +508,7 @@ class AllocatorFacadePrivate {
VLOG(8) << "Set default stream to " << stream
<< " for CUDAMallocAsyncAllocator(" << allocator.get() << ") in "
<< place;
#endif
} else {
PADDLE_THROW(platform::errors::NotFound(
"No StreamSafeCUDAAllocator or CUDAMallocAsyncAllocator found for "
@@ -511,45 +519,50 @@ class AllocatorFacadePrivate {

void RecordStream(std::shared_ptr<phi::Allocation> allocation,
gpuStream_t stream) {
if (auto cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(allocation)) {
cuda_malloc_async_allocation->RecordStream(stream);
} else if (auto stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(
allocation)) {
if (auto stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(allocation)) {
stream_safe_cuda_allocation->RecordStream(stream);
#ifdef PADDLE_WITH_CUDA
} else if (auto cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(
allocation)) {
cuda_malloc_async_allocation->RecordStream(stream);
#endif
} else {
VLOG(6) << "RecordStream for a non-StreamSafeCUDAAllocation";
}
}

void EraseStream(std::shared_ptr<phi::Allocation> allocation,
gpuStream_t stream) {
if (auto cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(allocation)) {
cuda_malloc_async_allocation->EraseStream(stream);
} else if (auto stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(
allocation)) {
if (auto stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(allocation)) {
stream_safe_cuda_allocation->EraseStream(stream);
#ifdef PADDLE_WITH_CUDA
} else if (auto cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(
allocation)) {
cuda_malloc_async_allocation->EraseStream(stream);
#endif
} else {
VLOG(6) << "EraseStream for a non-StreamSafeCUDAAllocation";
}
}

gpuStream_t GetStream(
const std::shared_ptr<phi::Allocation>& allocation) const {
if (const std::shared_ptr<CUDAMallocAsyncAllocation>
cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(
if (const std::shared_ptr<StreamSafeCUDAAllocation>
stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(
allocation)) {
return cuda_malloc_async_allocation->GetOwningStream();

} else if (const std::shared_ptr<StreamSafeCUDAAllocation>
stream_safe_cuda_allocation =
std::dynamic_pointer_cast<StreamSafeCUDAAllocation>(
allocation)) {
return stream_safe_cuda_allocation->GetOwningStream();
#ifdef PADDLE_WITH_CUDA
} else if (const std::shared_ptr<CUDAMallocAsyncAllocation>
cuda_malloc_async_allocation =
std::dynamic_pointer_cast<CUDAMallocAsyncAllocation>(
allocation)) {
return cuda_malloc_async_allocation->GetOwningStream();
#endif
}

VLOG(6) << "GetStream for a non-StreamSafeCUDAAllocation";
@@ -865,7 +878,7 @@ class AllocatorFacadePrivate {
return std::make_shared<CUDAAllocator>(p);
}

void InitStreamSafeCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) {
void InitCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) {
PADDLE_ENFORCE_EQ(
strategy_,
AllocatorStrategy::kAutoGrowth,
@@ -897,9 +910,14 @@ class AllocatorFacadePrivate {
}

void InitCUDAMallocAsyncAllocator(phi::GPUPlace p, gpuStream_t stream) {
#ifdef PADDLE_WITH_CUDA
std::shared_ptr<Allocator>& allocator = cuda_allocators_[p][stream];
cuda_allocators_[p][stream] =
std::make_shared<CUDAMallocAsyncAllocator>(allocator, p, stream);
#else
PADDLE_THROW(platform::errors::Unavailable(
"CUDAMallocAsyncAllocator is not enabled"));
#endif
}

void InitAutoGrowthCUDAAllocator(phi::GPUPlace p, gpuStream_t stream) {
@@ -1169,6 +1187,7 @@ class AllocatorFacadePrivate {
}

void WrapCUDAMallocAsyncAllocatorForDefault() {
#ifdef PADDLE_WITH_CUDA
for (auto& pair : allocators_) {
auto& place = pair.first;
if (phi::is_gpu_place(place)) {
Expand All @@ -1188,6 +1207,10 @@ class AllocatorFacadePrivate {
<< ", allocator address = " << pair.second.get();
}
}
#else
PADDLE_THROW(platform::errors::Unavailable(
"CUDAMallocAsyncAllocator is not enabled"));
#endif
}

void WrapCUDARetryAllocator(phi::GPUPlace p,
@@ -1547,12 +1570,15 @@ class AllocatorFacadePrivate {
// a standalone CUDA allocator to support multi-stream GC in new executor
std::map<phi::Place, std::shared_ptr<StreamSafeCUDAAllocator>>
default_stream_safe_cuda_allocators_;
std::map<phi::Place, std::shared_ptr<CUDAMallocAsyncAllocator>>
default_cuda_malloc_async_allocators_;
CUDAAllocatorMap cuda_allocators_;
std::shared_timed_mutex cuda_allocator_mutex_;
#endif

#if defined(PADDLE_WITH_CUDA)
std::map<platform::Place, std::shared_ptr<CUDAMallocAsyncAllocator>>
default_cuda_malloc_async_allocators_;
#endif

#ifdef PADDLE_WITH_XPU
// a standalone XPU allocator to support multi-stream GC in new executor
std::map<phi::Place, std::shared_ptr<StreamSafeXPUAllocator>>
@@ -1809,14 +1835,16 @@ void AllocatorFacade::PrepareMemoryPoolForCUDAGraph(int64_t id) {
FLAGS_allocator_strategy));
auto& allocator = cuda_graph_map_[id];
auto& ref_cnt = cuda_graph_ref_cnt_[id];
++ref_cnt;

if (FLAGS_use_cuda_malloc_async_allocator) return;
if (allocator.get() == nullptr) {
allocator = std::make_unique<AllocatorFacadePrivate>(
/*allow_free_idle_chunk=*/false);
VLOG(10) << "Create memory pool for CUDA Graph with memory ID " << id;
} else {
VLOG(10) << "Use created memory pool for CUDA Graph with memory ID " << id;
}
++ref_cnt;
}

void AllocatorFacade::RemoveMemoryPoolOfCUDAGraph(int64_t id) {