Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI decouple] move dropout_impl and cuda_graph_with_memory_pool from fluid to phi #49139

Merged
merged 12 commits into from
Dec 20, 2022
6 changes: 3 additions & 3 deletions paddle/fluid/memory/allocation/allocator_facade.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif

#if CUDA_VERSION >= 10020
Expand Down Expand Up @@ -157,7 +157,7 @@ class CUDAGraphAllocator

static bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
return UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing());
return UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing());
#else
return false;
#endif
Expand Down Expand Up @@ -1007,7 +1007,7 @@ AllocatorFacade& AllocatorFacade::Instance() {
AllocatorFacadePrivate* AllocatorFacade::GetPrivate() const {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
auto id = platform::CUDAGraph::CapturingPoolID();
auto id = phi::backends::gpu::CUDAGraph::CapturingPoolID();
auto iter = cuda_graph_map_.find(id);
PADDLE_ENFORCE_NE(
iter,
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/memory/allocation/stream_safe_cuda_allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "paddle/phi/backends/gpu/gpu_info.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif

namespace paddle {
Expand Down Expand Up @@ -49,7 +49,7 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {

std::lock_guard<SpinLock> lock_guard(outstanding_event_map_lock_);
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) {
graph_capturing_stream_set_.insert(stream);
return;
}
Expand All @@ -61,7 +61,7 @@ void StreamSafeCUDAAllocation::RecordStream(gpuStream_t stream) {

bool StreamSafeCUDAAllocation::CanBeFreed() {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(platform::CUDAGraph::IsThisThreadCapturing())) {
if (UNLIKELY(phi::backends::gpu::CUDAGraph::IsThisThreadCapturing())) {
return graph_capturing_stream_set_.empty() &&
outstanding_event_map_.empty();
}
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/memory/stream_safe_cuda_alloc_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ class StreamSafeCUDAAllocTest : public ::testing::Test {
data, result, data_num_);
RecordStream(data_allocation, other_stream);

std::unique_ptr<platform::CUDAGraph> cuda_graph =
std::unique_ptr<phi::backends::gpu::CUDAGraph> cuda_graph =
platform::EndCUDAGraphCapture();

int replay_times = 10;
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/cuda_graph_with_in_out.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class CUDAGraphWithInOuts {
int64_t PoolID() const { return graph_->PoolID(); }

private:
std::unique_ptr<platform::CUDAGraph> graph_;
std::unique_ptr<phi::backends::gpu::CUDAGraph> graph_;
std::vector<phi::DenseTensor> ins_;
std::vector<phi::DenseTensor> outs_;
std::vector<int64_t> in_indices_;
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/operators/fused/fmha_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ limitations under the License. */

#pragma once

#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/fused/fused_softmax_mask.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/dropout_impl.cu.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
Expand Down Expand Up @@ -206,7 +206,7 @@ class FMHARef {
stride_b = gemm_k * gemm_n;

if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
phi::funcs::DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
dropout_param_.dropout_prob_,
Expand Down Expand Up @@ -381,7 +381,7 @@ class FMHARef {
stride_b = gemm_k * gemm_n;

if (dropout_param_.dropout_prob_) {
DropoutFwGPUKernelDriver<T>(
phi::funcs::DropoutFwGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
dropout_param_.is_test_,
dropout_param_.dropout_prob_,
Expand Down Expand Up @@ -552,7 +552,7 @@ class FMHARef {
}
// dropout bw
if (dropout_param_.dropout_prob_) {
DropoutGradGPUKernelDriver<T>(
phi::funcs::DropoutGradGPUKernelDriver<T>(
static_cast<const phi::GPUContext&>(dev_ctx_),
false,
dropout_param_.dropout_prob_,
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/fused_dropout_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ limitations under the License. */
#pragma once

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/dropout_impl_util.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/phi/kernels/funcs/dropout_impl_util.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

Expand Down Expand Up @@ -106,7 +106,7 @@ struct DropoutParam {

int UpdateSeedAndIncrement(const phi::GPUContext& ctx, const int offset) {
uint64_t tmp_increment;
GetSeedDataAndIncrement(
phi::funcs::GetSeedDataAndIncrement(
ctx, tensor_seed, fix_seed, seed_val, offset, &seed, &tmp_increment);
increment = static_cast<int>(tmp_increment);
return increment;
Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/backends/all_context.h"

DECLARE_bool(use_stream_safe_cuda_allocator);

namespace paddle {
namespace platform {

#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();

Expand Down Expand Up @@ -64,7 +64,7 @@ void BeginCUDAGraphCapture(platform::CUDAPlace place,

std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
auto place = CUDAGraph::CapturingPlace();
auto* mutable_dev_ctx = platform::DeviceContextPool::Instance().Get(place);
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
dev_ctx->SetCUDAGraphAllocator(nullptr);
Expand Down
111 changes: 13 additions & 98 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,123 +14,38 @@

#pragma once

#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#endif
#include "paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/macros.h"

namespace paddle {
namespace platform {

#ifdef PADDLE_WITH_CUDA
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
if (::paddle::platform::CUDAGraph::IsThisThreadCapturing() && (__cond)) { \
using __Helper = \
::paddle::platform::IsSameKernelHelper<decltype(&__kernel_func), \
&__kernel_func>; \
auto *dev_ctx = \
::paddle::platform::DeviceContextPool::Instance().GetByPlace( \
::paddle::platform::CUDAGraph::CapturingPlace()); \
auto __set_seed_func = \
[=](::paddle::platform::CUDAKernelParams *__params, \
bool __check_only) -> bool { \
if (__check_only) { \
return __params->func() == &__kernel_func && \
__Helper::Compare(*__params, __VA_ARGS__); \
} \
auto &KERNEL_PARAMS = *__params; \
uint64_t __seed, __offset; \
::paddle::operators::GetSeedDataAndIncrement( \
*dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \
__seed_expr = static_cast<decltype(__seed_expr)>(__seed); \
__offset_expr = static_cast<decltype(__offset_expr)>(__offset); \
return true; \
}; \
::paddle::platform::CUDAGraph::RecordRandomKernelInfo(__set_seed_func); \
} \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#else
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#endif

// NOTE: These APIs are not thread-safe.
#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(platform::CUDAPlace place,
using CUDAGraph = phi::backends::gpu::CUDAGraph;

void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id = CUDAGraph::kInvalidPoolID);
std::unique_ptr<CUDAGraph> EndCUDAGraphCapture();
#endif

inline bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::IsCapturing();
#else
return false;
#endif
}

inline platform::CUDAPlace CUDAGraphCapturingPlace() {
inline phi::GPUPlace CUDAGraphCapturingPlace() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::CapturingPlace();
#else
PADDLE_THROW(platform::errors::Unimplemented(
PADDLE_THROW(phi::errors::Unimplemented(
"CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}

// Add reset callback if CUDA Graph is capturing.
// Otherwise, invoke callback directly.
template <typename Callback>
inline void AddResetCallbackIfCapturingCUDAGraph(Callback &&callback) {
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
return CUDAGraph::AddResetCallbackDuringCapturing(
std::forward<Callback>(callback));
}
#endif
callback();
}
using phi::backends::gpu::IsCUDAGraphCapturing;

template <typename T>
inline T *RestoreHostMemIfCapturingCUDAGraph(T *host_mem, size_t size) {
static_assert(std::is_trivial<T>::value, "T must be trivial type");
static_assert(!std::is_same<T, void>::value, "T cannot be void");
#ifdef PADDLE_WITH_CUDA
if (UNLIKELY(IsCUDAGraphCapturing())) {
size_t nbytes = size * sizeof(T);
void *new_host_mem = new uint8_t[nbytes];
std::memcpy(new_host_mem, host_mem, nbytes);
AddResetCallbackIfCapturingCUDAGraph(
[new_host_mem] { delete[] reinterpret_cast<uint8_t *>(new_host_mem); });
return reinterpret_cast<T *>(new_host_mem);
}
#endif
return host_mem;
}
using phi::backends::gpu::AddResetCallbackIfCapturingCUDAGraph;

using phi::backends::gpu::RestoreHostMemIfCapturingCUDAGraph;

class SkipCUDAGraphCaptureGuard {
DISABLE_COPY_AND_ASSIGN(SkipCUDAGraphCaptureGuard);
Expand Down
78 changes: 0 additions & 78 deletions paddle/fluid/platform/device/gpu/cuda/cuda_graph.h

This file was deleted.

4 changes: 2 additions & 2 deletions paddle/fluid/platform/device/gpu/gpu_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ limitations under the License. */
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/miopen.h"
#else
#include "paddle/fluid/platform/device/gpu/cuda/cuda_graph.h"
#include "paddle/fluid/platform/dynload/cudnn.h"
#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#endif

#ifdef PADDLE_WITH_CUDA
Expand Down Expand Up @@ -230,7 +230,7 @@ class RecordedGpuMallocHelper {
result = hipMalloc(ptr, size);
}
#else
CUDAGraphCaptureModeGuard capture_mode_guard;
phi::backends::gpu::CUDAGraphCaptureModeGuard capture_mode_guard;
if (UNLIKELY(malloc_managed_memory)) {
result = cudaMallocManaged(ptr, size);
} else {
Expand Down
Loading