Skip to content

Commit

Permalink
add rocm malloc async
Browse files Browse the repository at this point in the history
  • Loading branch information
weihanmines committed Jan 11, 2024
1 parent e7a2ef5 commit d7df941
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 56 deletions.
7 changes: 5 additions & 2 deletions xla/stream_executor/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ tsl_gpu_library(
srcs = [
"gpu_cudamallocasync_allocator.cc",
],
hdrs = ["gpu_cudamallocasync_allocator.h"],
hdrs = ["gpu_cudamallocasync_allocator.h",
"gpu_types.h",
],
cuda_deps = [
"//xla/stream_executor/cuda:cuda_activation",
"//xla/stream_executor/cuda:cuda_executor",
Expand All @@ -545,7 +547,8 @@ tsl_gpu_library(
"@tsl//tsl/platform:macros",
"@tsl//tsl/platform:mutex",
"@tsl//tsl/util:env_var",
],
] + if_rocm_is_configured([
"//xla/stream_executor/rocm:rocm_activation"]),
)

cc_library(
Expand Down
171 changes: 121 additions & 50 deletions xla/stream_executor/gpu/gpu_cudamallocasync_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,62 @@ limitations under the License.
#include <vector>

#ifdef GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#define GpuGetErrorString cuGetErrorString
#define GpuGetErrorName cuGetErrorName
#define GpuDriverGetVersion cuDriverGetVersion
#define GpuDevicePrimaryCtxRetain cuDevicePrimaryCtxRetain
#define GpuMemPoolGetAttribute cuMemPoolGetAttribute
#define GpuMemPoolSetAttribute cuMemPoolSetAttribute
#define GPU_MEMPOOL_ATTR_RELEASE_THRESHOLD CU_MEMPOOL_ATTR_RELEASE_THRESHOLD
#define GpuMemFreeAsync cuMemFreeAsync
#define GPU_ERROR_DEINITIALIZED CUDA_ERROR_DEINITIALIZED
#define GpuMemGetInfo cuMemGetInfo
#define GpuMemAllocFromPoolAsync cuMemAllocFromPoolAsync
#define GpuStreamSynchronize cuStreamSynchronize
#define GPU_ERROR_OUT_OF_MEMORY CUDA_ERROR_OUT_OF_MEMORY
#define GpuDeviceGetAttribute cuDeviceGetAttribute
#define GPU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED
#define GPU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC
#define GPU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT
#define GPU_MEMPOOL_ATTR_USED_MEM_CURRENT CU_MEMPOOL_ATTR_USED_MEM_CURRENT
#define GPU_MEMPOOL_ATTR_RESERVED_MEM_HIGH CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH
#define GPU_MEMPOOL_ATTR_USED_MEM_HIGH CU_MEMPOOL_ATTR_USED_MEM_HIGH
#define GPU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES
#define GPU_MEM_ACCESS_FLAGS_PROT_READWRITE CU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define GPU_MEM_LOCATION_TYPE_DEVICE CU_MEM_LOCATION_TYPE_DEVICE
#define GpuDeviceGetDefaultMemPool cuDeviceGetDefaultMemPool
#define GpuDeviceCanAccessPeer cuDeviceCanAccessPeer
#define GpuMemPoolSetAccess cuMemPoolSetAccess
#include "xla/stream_executor/cuda/cuda_activation.h"
#elif TENSORFLOW_USE_ROCM
using cuuint64_t = uint64_t;
#define GpuGetErrorString hipGetErrorString
#define GpuGetErrorName hipGetErrorName
#define GpuDriverGetVersion hipDriverGetVersion
#define GpuDevicePrimaryCtxRetain hipDevicePrimaryCtxRetain
#define GpuMemPoolGetAttribute hipMemPoolGetAttribute
#define GpuMemPoolSetAttribute hipMemPoolSetAttribute
#define GPU_MEMPOOL_ATTR_RELEASE_THRESHOLD hipMemPoolAttrReleaseThreshold
#define GpuMemFreeAsync hipFreeAsync
#define GPU_ERROR_DEINITIALIZED hipErrorDeinitialized
#define GpuMemGetInfo hipMemGetInfo
#define GpuMemAllocFromPoolAsync hipMallocFromPoolAsync
#define GpuStreamSynchronize hipStreamSynchronize
#define GPU_ERROR_OUT_OF_MEMORY hipErrorOutOfMemory
#define GpuDeviceGetAttribute hipDeviceGetAttribute
#define GPU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED hipDeviceAttributeMemoryPoolsSupported
#define GPU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC hipMemPoolReuseAllowOpportunistic
#define GPU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT hipMemPoolAttrReservedMemCurrent
#define GPU_MEMPOOL_ATTR_USED_MEM_CURRENT hipMemPoolAttrUsedMemCurrent
#define GPU_MEMPOOL_ATTR_RESERVED_MEM_HIGH hipMemPoolAttrReservedMemHigh
#define GPU_MEMPOOL_ATTR_USED_MEM_HIGH hipMemPoolAttrUsedMemHigh
#define GPU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES hipMemPoolReuseAllowInternalDependencies
#define GPU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define GPU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define GpuDeviceGetDefaultMemPool hipDeviceGetDefaultMemPool
#define GpuDeviceCanAccessPeer hipDeviceCanAccessPeer
#define GpuMemPoolSetAccess hipMemPoolSetAccess
#include "xla/stream_executor/rocm/rocm_activation.h"
#endif // GOOGLE_CUDA

#include "absl/strings/str_cat.h"
Expand All @@ -38,12 +92,17 @@ limitations under the License.

namespace stream_executor {

#if GOOGLE_CUDA
static std::string GetCudaErrorMessage(CUresult result) {
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
static std::string GetCudaErrorMessage(gpu::GpuStatus result) {
const char* error;
cuGetErrorString(result, &error);
const char* name;
cuGetErrorName(result, &name);
#if GOOGLE_CUDA
GpuGetErrorString(result, &error);
GpuGetErrorName(result, &name);
#elif TENSORFLOW_USE_ROCM
error = GpuGetErrorString(result);
name = GpuGetErrorName(result);
#endif
return absl::StrCat("CUDA error: ", error ? error : "<unknown>", " (",
name ? name : "Unknown", ")");
}
Expand All @@ -68,27 +127,27 @@ void GpuCudaMallocAsyncAllocator::PrintAllocatorStatisticsNoLock() {
VLOG(8) << "\nThe sorted list of (ptr,size):";
VLOG(8) << absl::StrJoin(ptr_size_string, ",");

#if CUDA_VERSION >= 11030
#if CUDA_VERSION >= 11030 || TF_ROCM_VERSION >= 50300
cuuint64_t mem_reserved_current;
if (auto result = cuMemPoolGetAttribute(
pool_, CU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, &mem_reserved_current)) {
if (auto result = GpuMemPoolGetAttribute(
pool_, GPU_MEMPOOL_ATTR_RESERVED_MEM_CURRENT, &mem_reserved_current)) {
LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
<< GetCudaErrorMessage(result);
}
cuuint64_t mem_used_current;
if (auto result = cuMemPoolGetAttribute(
pool_, CU_MEMPOOL_ATTR_USED_MEM_CURRENT, &mem_used_current)) {
if (auto result = GpuMemPoolGetAttribute(
pool_, GPU_MEMPOOL_ATTR_USED_MEM_CURRENT, &mem_used_current)) {
LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
<< GetCudaErrorMessage(result);
}
cuuint64_t mem_reserved_high;
if (auto result = cuMemPoolGetAttribute(
pool_, CU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &mem_reserved_high)) {
if (auto result = GpuMemPoolGetAttribute(
pool_, GPU_MEMPOOL_ATTR_RESERVED_MEM_HIGH, &mem_reserved_high)) {
LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
<< GetCudaErrorMessage(result);
}
cuuint64_t mem_used_high;
if (auto result = cuMemPoolGetAttribute(pool_, CU_MEMPOOL_ATTR_USED_MEM_HIGH,
if (auto result = GpuMemPoolGetAttribute(pool_, GPU_MEMPOOL_ATTR_USED_MEM_HIGH,
&mem_used_high)) {
LOG(ERROR) << "Error while fetching extra cudaMallocAsync pool attribute: "
<< GetCudaErrorMessage(result);
Expand Down Expand Up @@ -123,9 +182,14 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
pool_ = nullptr;
cuda_stream_ = nullptr;
int driverVersion;
cuDriverGetVersion(&driverVersion);
GpuDriverGetVersion(&driverVersion);
VLOG(2) << "DRIVER VERSION: " << driverVersion;
if (driverVersion < 11020) {
#if GOOGLE_CUDA
if (driverVersion < 11020)
#elif TENSORFLOW_USE_ROCM
if (driverVersion < 50300)
#endif
{
LOG(FATAL) // Crash OK.
<< "Disable cuda_malloc_async or update your CUDA driver to a version"
<< " compatible with CUDA 11.2 or higher."
Expand All @@ -135,17 +199,24 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
// WAR an CUDA 11.2 driver bug for multiple-GPU. It currently
// request that the context on GPU 0 is initialized. Which isn't the
// case for TF+horovod.
if (platform_device_id.value() > 0 && driverVersion < 11030) {
CUcontext pctx; // We loose track of it. But this is fine.
if (auto result = cuDevicePrimaryCtxRetain(&pctx, 0))
if (platform_device_id.value() > 0) {
#if GOOGLE_CUDA
if (driverVersion < 11020)
#elif TENSORFLOW_USE_ROCM
if (driverVersion < 50300)
#endif
{
GpuContextHandle pctx; // We loose track of it. But this is fine.
if (auto result = GpuDevicePrimaryCtxRetain(&pctx, 0))
LOG(FATAL) // Crash OK.
<< "Failed to retain context: " << GetCudaErrorMessage(result);
}
}

cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
gpu::ScopedActivateExecutorContext scoped_activation{stream_exec_};

// Check the CUDA runtime is recent enough.
if (auto status2 = cuDriverGetVersion(&driverVersion)) {
if (auto status2 = GpuDriverGetVersion(&driverVersion)) {
LOG(FATAL) // Crash OK.
<< "Error while fetching driver version: "
<< GetCudaErrorMessage(status2);
Expand All @@ -154,8 +225,8 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
// Check that cudaMallocAsync is supported.
int cuda_malloc_async_supported;
if (auto status =
cuDeviceGetAttribute(&cuda_malloc_async_supported,
CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
GpuDeviceGetAttribute(&cuda_malloc_async_supported,
GPU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED,
platform_device_id.value())) {
LOG(FATAL) // Crash OK.
<< "On device: " << platform_device_id.value()
Expand All @@ -171,16 +242,16 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
<< " OS not supported, CUDA version too old(request CUDA11.2+).";

if (auto status =
cuDeviceGetDefaultMemPool(&pool_, platform_device_id.value()))
GpuDeviceGetDefaultMemPool(&pool_, platform_device_id.value()))
LOG(FATAL) << // Crash OK.
"Failed to get default CUDA pool: " << GetCudaErrorMessage(status);

VLOG(1) << Name() << " CudaMallocAsync initialized on platform: "
<< platform_device_id.value() << " with pool size of: " << pool_size
<< " this ptr: " << this;
uint64_t pool_size_64 = pool_size;
if (auto status = cuMemPoolSetAttribute(
pool_, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64))
if (auto status = GpuMemPoolSetAttribute(
pool_, GPU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64))
LOG(FATAL) << // Crash OK.
"Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);

Expand All @@ -196,34 +267,34 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
/*default_val=*/false, &deterministic));
if (deterministic) {
int disable = 0;
if (auto status = cuMemPoolSetAttribute(
pool_, CU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, &disable)) {
if (auto status = GpuMemPoolSetAttribute(
pool_, GPU_MEMPOOL_ATTR_REUSE_ALLOW_OPPORTUNISTIC, &disable)) {
LOG(FATAL) << // Crash OK.
"Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);
}
if (auto status = cuMemPoolSetAttribute(
pool_, CU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES,
if (auto status = GpuMemPoolSetAttribute(
pool_, GPU_MEMPOOL_ATTR_REUSE_ALLOW_INTERNAL_DEPENDENCIES,
&disable)) {
LOG(FATAL) << // Crash OK.
"Failed to set CUDA pool attribute: " << GetCudaErrorMessage(status);
}
}

// Set read/write access to all GPUs.
static auto* all_pools_ = new std::vector<CUmemoryPool*>();
static auto* all_pools_ = new std::vector<GpuMemoryPoolHandle*>();
static auto* all_ids_ = new std::vector<tsl::PlatformDeviceId>();
DCHECK(all_pools_->size() == all_ids_->size());
for (int i = 0; i < all_pools_->size(); ++i) {
// Set the current pool access to the previous GPUs.
CUmemAccessDesc map;
map.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
gpu::GpuMemAccessDesc map;
map.flags = GPU_MEM_ACCESS_FLAGS_PROT_READWRITE;
map.location.id = (*all_ids_)[i].value();

map.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
map.location.type = GPU_MEM_LOCATION_TYPE_DEVICE;
VLOG(2) << "Setting access of the current pool to "
<< " location id: " << map.location.id;
int canAccessPeer;
if (auto status = cuDeviceCanAccessPeer(
if (auto status = GpuDeviceCanAccessPeer(
&canAccessPeer, platform_device_id.value(), map.location.id)) {
pool_ = nullptr;
LOG(FATAL) // Crash OK.
Expand All @@ -232,7 +303,7 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(
<< platform_device_id.value() << ": " << GetCudaErrorMessage(status);
}
if (canAccessPeer == 1) {
if (auto status = cuMemPoolSetAccess(pool_, &map, 1)) {
if (auto status = GpuMemPoolSetAccess(pool_, &map, 1)) {
pool_ = nullptr;
LOG(FATAL) // Crash OK.
<< "Error when setting access to the pool id: " << i
Expand All @@ -246,14 +317,14 @@ GpuCudaMallocAsyncAllocator::GpuCudaMallocAsyncAllocator(

VLOG(2) << "Set access to the pool id: " << i
<< " location id: " << map.location.id;
if (auto status = cuDeviceCanAccessPeer(&canAccessPeer, i,
if (auto status = GpuDeviceCanAccessPeer(&canAccessPeer, i,
platform_device_id.value())) {
pool_ = nullptr;
LOG(FATAL) // Crash OK.
<< "cuDeviceCanAccessPeer failed: " << GetCudaErrorMessage(status);
}
if (canAccessPeer == 1) {
if (auto status = cuMemPoolSetAccess(*(*all_pools_)[i], &map, 1)) {
if (auto status = GpuMemPoolSetAccess(*(*all_pools_)[i], &map, 1)) {
pool_ = nullptr;
LOG(FATAL) // Crash OK.
<< "Error when setting access to the pool id: " << i
Expand Down Expand Up @@ -290,21 +361,21 @@ void* GpuCudaMallocAsyncAllocator::AllocateRaw(size_t alignment,
if (stats_) {
lock.lock();
}
cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
gpu::ScopedActivateExecutorContext scoped_activation{stream_exec_};
void* ptr = nullptr;
auto result = cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
auto result = GpuMemAllocFromPoolAsync(reinterpret_cast<GpuDevicePtr*>(&ptr),
num_bytes, pool_, cuda_stream_);
if (result == CUDA_ERROR_OUT_OF_MEMORY) {
if (result == GPU_ERROR_OUT_OF_MEMORY) {
// Doing a stream synchronization give the driver more flexibility
// for blocks coalescing and doing memory remapping. So it can
// solve some OOM cases when memory is tight.
cuStreamSynchronize(cuda_stream_);
result = cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr*>(&ptr),
GpuStreamSynchronize(cuda_stream_);
result = GpuMemAllocFromPoolAsync(reinterpret_cast<GpuDevicePtr*>(&ptr),
num_bytes, pool_, cuda_stream_);
}
if (result) {
size_t free, total;
cuMemGetInfo(&free, &total);
GpuMemGetInfo(&free, &total);
LOG(ERROR) << Name() << " cuMemAllocAsync failed to allocate " << num_bytes
<< " bytes: " << GetCudaErrorMessage(result)
<< "\n Reported by CUDA: Free memory/Total memory: " << free
Expand Down Expand Up @@ -347,17 +418,17 @@ void GpuCudaMallocAsyncAllocator::DeallocateRaw(void* ptr) {
if (stats_) {
lock.lock();
}
if (auto result = cuMemFreeAsync(reinterpret_cast<const CUdeviceptr&>(ptr),
if (auto result = GpuMemFreeAsync(reinterpret_cast<const GpuDevicePtr&>(ptr),
cuda_stream_)) {
if (result == CUDA_ERROR_DEINITIALIZED) {
if (result == GPU_ERROR_DEINITIALIZED) {
// It happens with multi-GPU that TF free the GPU allocation after
// the driver is unloaded. It is safe to ignore this error here.
// TODO: Find how to fix the shutdown steps in TF.
VLOG(1) << "Ignoring CUDA error: " << GetCudaErrorMessage(result);
} else {
size_t free, total;
cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
cuMemGetInfo(&free, &total);
gpu::ScopedActivateExecutorContext scoped_activation{stream_exec_};
GpuMemGetInfo(&free, &total);
LOG(ERROR) << "cudaFreeAsync failed to free " << ptr << ": "
<< GetCudaErrorMessage(result)
<< "\n Free memory/Total memory: " << free << "/" << total;
Expand Down Expand Up @@ -412,16 +483,16 @@ bool GpuCudaMallocAsyncAllocator::ClearStats() {

void GpuCudaMallocAsyncAllocator::SetStreamAndPreallocateMemory(void* stream) {
#if TF_CUDA_MALLOC_ASYNC_SUPPORTED
auto new_cuda_stream = static_cast<CUstream>(stream);
auto new_cuda_stream = static_cast<GpuStreamHandle>(stream);
// We don't need to re-set the CUDA stream if this is the same stream
if (cuda_stream_ != nullptr && new_cuda_stream != cuda_stream_) {
LOG(FATAL) << // Crash OK.
"Trying to set the stream twice. This isn't supported. ";
}

uint64_t pool_size_64 = 0;
if (auto status = cuMemPoolGetAttribute(
pool_, CU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64)) {
if (auto status = GpuMemPoolGetAttribute(
pool_, GPU_MEMPOOL_ATTR_RELEASE_THRESHOLD, &pool_size_64)) {
LOG(FATAL) << // Crash OK.
"Failed to get CUDA pool attribute: " << GetCudaErrorMessage(status);
}
Expand Down
9 changes: 5 additions & 4 deletions xla/stream_executor/gpu/gpu_cudamallocasync_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,12 @@ limitations under the License.
#include "tsl/framework/device_id.h"
#include "tsl/platform/macros.h"
#include "tsl/platform/mutex.h"
#include "xla/stream_executor/gpu/gpu_types.h"

#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"

#define TF_CUDA_MALLOC_ASYNC_SUPPORTED CUDA_VERSION >= 11020
#elif TENSORFLOW_USE_ROCM
#define TF_CUDA_MALLOC_ASYNC_SUPPORTED TF_ROCM_VERSION >= 50300
#endif // GOOGLE_CUDA


Expand Down Expand Up @@ -105,12 +106,12 @@ class GpuCudaMallocAsyncAllocator : public tsl::Allocator {
// stream. So we do not need to ask cudaMallocAsync to add extra
// synchronization.
// Not owned.
CUstream cuda_stream_;
GpuStreamHandle cuda_stream_;

// Not owned. The default pool of the associated GPU.
// If null, then the instanciation failed and the first allocation
// will return an error.
CUmemoryPool pool_;
GpuMemoryPoolHandle pool_;
#endif // TF_CUDA_MALLOC_ASYNC_SUPPORTED

// Just a counter for the number of time this class is instantiated.
Expand Down
Loading

0 comments on commit d7df941

Please sign in to comment.