diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index f1dd2310a697a..f910e67962f72 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -358,7 +358,7 @@ function(hip_library TARGET_NAME)
     else()
       add_library(${TARGET_NAME} STATIC ${_cmake_options} ${_generated_files} ${_sources})
       set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX)
-      target_link_libraries(${TARGET_NAME} /opt/rocm/hip/lib/libhip_hcc.so /opt/rocm/hip/lib/libhip_device.a /opt/rocm/rccl/lib/librccl.so /opt/rocm/hiprand/lib/libhiprand.so)
+      target_link_libraries(${TARGET_NAME} "-Wl,--no-as-needed" /opt/rocm/hip/lib/libhip_hcc.so /opt/rocm/rccl/lib/librccl.so /opt/rocm/hiprand/lib/libhiprand.so /opt/rocm/hip/lib/libhip_device.a)
       find_fluid_modules(${TARGET_NAME})
     endif()
     if("${hip_library_DEPS}" MATCHES "ARCHIVE_START")
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index 800ee690ebae6..eecf7f6e2e3a3 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -95,6 +95,7 @@ void AllReduceOpHandle::RunImpl() {
         auto stream = nccl_ctx.stream();
         auto comm = nccl_ctx.comm_;
 #ifdef PADDLE_WITH_HIP
+        stream = static_cast<platform::CUDADeviceContext *>(dev_ctxes_[p])->stream();
         all_reduce_calls.emplace_back([=] {
           PADDLE_ENFORCE(platform::dynload::rcclAllReduce(
               buffer, buffer, numel, static_cast<rcclDataType_t>(dtype),
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cu b/paddle/fluid/framework/details/all_reduce_op_handle.cu
new file mode 100644
index 0000000000000..828f7bcdfd52e
--- /dev/null
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cu
@@ -0,0 +1,323 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <algorithm>
+
+#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
+#include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/framework/details/reduce_and_gather.h"
+#include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
+
+DEFINE_bool(allreduce_check, false, "If set, check allreduce result.");
+DEFINE_bool(allreduce_single_stream, true,
+            "If set, run allreduce on each device's compute stream.");
+DEFINE_bool(allreduce_use_cpu, false, "If set, perform allreduce on the CPU.");
+DEFINE_int32(allreduce_thread, 512, "Threads per block of the allreduce kernel.");
+DEFINE_int32(allreduce_grid, 64, "Grid size of the allreduce kernel.");
+
+namespace paddle {
+namespace framework {
+namespace details {
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] += B[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C, T* D) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx] + D[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C, T* D, T* E) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx] + D[idx] + E[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C, T* D,
+                              T* E, T* F) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx] + D[idx] + E[idx] + F[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C, T* D,
+                              T* E, T* F, T* G) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx] + D[idx] + E[idx] + F[idx] + G[idx];
+  }
+}
+
+template <typename T>
+__global__ void allreduce_sum(size_t lens, T* A, T* B, T* C, T* D,
+                              T* E, T* F, T* G, T* H) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+  for (; idx < lens; idx += blockDim.x * gridDim.x) {
+    A[idx] = A[idx] + B[idx] + C[idx] + D[idx] + E[idx] + F[idx] + G[idx] +
+             H[idx];
+  }
+}
+
+#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places,
+                                     const platform::NCCLContextMap *ctxs)
+    : OpHandleBase(node),
+      local_scopes_(local_scopes),
+      places_(places),
+      nccl_ctxs_(ctxs) {
+  if (nccl_ctxs_) {
+    for (auto &p : places_) {
+      this->dev_ctxes_[p] = nccl_ctxs_->DevCtx(p);
+    }
+  }
+}
+#else
+AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
+                                     const std::vector<Scope *> &local_scopes,
+                                     const std::vector<platform::Place> &places)
+    : OpHandleBase(node), local_scopes_(local_scopes), places_(places) {}
+#endif
+
+void AllReduceOpHandle::RunImpl() {
+  platform::RecordEvent r("all_reduce", nullptr);
+  if (NoDummyInputSize() == 1) {
+    return;  // No need to all reduce when GPU count = 1;
+  } else {
+    // Wait input done
+    WaitInputVarGenerated();
+    auto in_var_handles = DynamicCast<VarHandle>(this->Inputs());
+    auto out_var_handles = DynamicCast<VarHandle>(this->Outputs());
+    PADDLE_ENFORCE_EQ(
+        in_var_handles.size(), places_.size(),
+        "The NoDummyInputSize should be equal to the number of places.");
+    PADDLE_ENFORCE_EQ(
+        in_var_handles.size(), out_var_handles.size(),
+        "The NoDummyInputSize and NoDummyOutputSize should be equal.");
+
+    std::vector<const LoDTensor *> lod_tensors;
+    for (size_t i = 0; i < local_scopes_.size(); ++i) {
+      auto *s = local_scopes_[i];
+      auto &local_scope = *s->FindVar(kLocalExecScopeName)->Get<Scope *>();
+      auto &lod_tensor =
+          local_scope.FindVar(in_var_handles[i]->name_)->Get<LoDTensor>();
+      lod_tensors.emplace_back(&lod_tensor);
+      PADDLE_ENFORCE_EQ(in_var_handles[i]->name_, out_var_handles[i]->name_,
+                        "The name of input and output should be equal.");
+    }
+
+    VLOG(1) << "device num " << local_scopes_.size() << " size "
+            << lod_tensors[0]->numel();
+    if (platform::is_gpu_place(lod_tensors[0]->place())) {
+      const bool use_cpu = FLAGS_allreduce_use_cpu || (local_scopes_.size() > 8);
+      const bool check = FLAGS_allreduce_check && !use_cpu;
+      const int threads = FLAGS_allreduce_thread;
+      const int grid = FLAGS_allreduce_grid;
+#if defined(PADDLE_WITH_HIP)
+      size_t numel = lod_tensors[0]->numel();
+      std::vector<float *> buffers;
+      std::vector<hipStream_t> streams;
+      framework::Tensor output_cpu[local_scopes_.size()];
+      hipEvent_t events[local_scopes_.size()];
+      framework::Tensor final_output_cpu;
+      for (size_t i = 0; i < local_scopes_.size(); ++i) {
+        auto &p = places_[i];
+        auto &lod_tensor = *lod_tensors[i];
+        float* buffer = const_cast<float *>(lod_tensor.data<float>());
+        buffers.emplace_back(buffer);
+
+        int dev_id = boost::get<platform::CUDAPlace>(p).device;
+        auto &nccl_ctx = nccl_ctxs_->at(dev_id);
+        // auto stream = nccl_ctx.stream();
+        auto stream =
+            static_cast<platform::CUDADeviceContext *>(dev_ctxes_[p])->stream();
+        if (FLAGS_allreduce_single_stream) {
+          streams.emplace_back(stream);
+        } else {
+          streams.emplace_back(nccl_ctx.stream());
+          hipStreamSynchronize(stream);
+        }
+
+        if (use_cpu || check) {
+          output_cpu[i].mutable_data<float>(lod_tensor.dims(),
+                                            platform::CPUPlace());
+          framework::TensorCopySync(lod_tensor, platform::CPUPlace(),
+                                    &(output_cpu[i]));
+        }
+      }
+      if (use_cpu) {
+        // merge all original data to final_output_cpu at CPU side
+        final_output_cpu.mutable_data<float>(lod_tensors[0]->dims(),
+                                             platform::CPUPlace());
+        for (size_t i = 0; i < local_scopes_.size(); ++i) {
+          for (int j = 0; j < output_cpu[i].numel(); j++)
+            if (i == 0)
+              (final_output_cpu.data<float>())[j] =
+                  (output_cpu[i].data<float>())[j];
+            else
+              (final_output_cpu.data<float>())[j] +=
+                  (output_cpu[i].data<float>())[j];
+        }
+
+        // copy merged data back to different GPUs
+        for (size_t i = 0; i < local_scopes_.size(); ++i) {
+          auto &lod_tensor = const_cast<LoDTensor &>(*lod_tensors[i]);
+          auto &p = places_[i];
+
+          framework::TensorCopySync(final_output_cpu, p, &lod_tensor);
+        }
+      } else {
+        static int ring_index = 0;
+        int idx[local_scopes_.size()];
+
+        for (size_t i = 0; i < local_scopes_.size(); ++i) {
+          idx[i] = (ring_index + i) % local_scopes_.size();
+          hipEventCreate(&events[i]);
+        }
+
+        /* sync all streams */
+        for (size_t i = 1; i < local_scopes_.size(); ++i) {
+          hipEventRecord(events[idx[i]], streams[idx[i]]);
+          hipStreamWaitEvent(streams[idx[0]], events[idx[i]], 0);
+        }
+
+        if (local_scopes_.size() == 2)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]]);
+        else if (local_scopes_.size() == 3)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]]);
+        else if (local_scopes_.size() == 4)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]],
+                             buffers[idx[3]]);
+        else if (local_scopes_.size() == 5)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]],
+                             buffers[idx[3]], buffers[idx[4]]);
+        else if (local_scopes_.size() == 6)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]],
+                             buffers[idx[3]], buffers[idx[4]], buffers[idx[5]]);
+        else if (local_scopes_.size() == 7)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]],
+                             buffers[idx[3]], buffers[idx[4]], buffers[idx[5]],
+                             buffers[idx[6]]);
+        else if (local_scopes_.size() == 8)
+          hipLaunchKernelGGL((allreduce_sum), dim3(grid), dim3(threads), 0,
+                             streams[idx[0]], numel,
+                             buffers[idx[0]], buffers[idx[1]], buffers[idx[2]],
+                             buffers[idx[3]], buffers[idx[4]], buffers[idx[5]],
+                             buffers[idx[6]], buffers[idx[7]]);
+
+        /* broadcast results to all gpus */
+        for (size_t dst_i = 1, src_i = 0, ready_count = 1;
+             dst_i < local_scopes_.size(); ++dst_i) {
+          hipMemcpyAsync(buffers[idx[dst_i]], buffers[idx[src_i]],
+                         numel * sizeof(float), hipMemcpyDeviceToDevice,
+                         streams[idx[src_i]]);
+          hipEventRecord(events[idx[dst_i]], streams[idx[src_i]]);
+          hipStreamWaitEvent(streams[idx[dst_i]], events[idx[dst_i]], 0);
+          if ((--ready_count) == 0) {
+            ready_count = dst_i + 1;
+            src_i = 0;
+          } else {
+            src_i++;
+          }
+        }
+
+        ring_index = (ring_index + 1) % local_scopes_.size();
+
+        if (!FLAGS_allreduce_single_stream)
+          for (size_t i = 0; i < local_scopes_.size(); ++i)
+            hipStreamSynchronize(streams[i]);
+
+        for (size_t i = 0; i < local_scopes_.size(); ++i)
+          hipEventDestroy(events[i]);
+      }
+      if (check) {
+        final_output_cpu.mutable_data<float>(lod_tensors[0]->dims(),
+                                             platform::CPUPlace());
+        for (int k = 0; k < local_scopes_.size(); ++k) {
+          VLOG(1) << "checking " << k;
+          VLOG(1) << " org " << (output_cpu[k].data<float>())[0];
+          framework::TensorCopySync(*lod_tensors[k], platform::CPUPlace(),
+                                    &(final_output_cpu));
+          for (int j = 0; j < numel; j++) {
+            float temp = 0.0f;
+            for (size_t i = 0; i < local_scopes_.size(); ++i)
+              temp += (output_cpu[i].data<float>())[j];
+            if ((j == 0) ||
+                (temp - (final_output_cpu.data<float>())[j] > 1e-3 ||
+                 temp - (final_output_cpu.data<float>())[j] < -1e-3))
+              VLOG(1) << "data " << temp << " "
+                      << (final_output_cpu.data<float>())[j];
+          }
+        }
+      }
+#else
+      PADDLE_THROW("Not compiled with CUDA");
+#endif
+    } else {  // Special handle CPU only Operator's gradient. Like CRF
+      auto &trg = *this->local_scopes_[0]
+                       ->FindVar(kLocalExecScopeName)
+                       ->Get<Scope *>()
+                       ->FindVar(out_var_handles[0]->name_)
+                       ->GetMutable<framework::LoDTensor>();
+
+      // Reduce All Tensor to trg in CPU
+      ReduceLoDTensor func(lod_tensors, &trg);
+      VisitDataType(ToDataType(lod_tensors[0]->type()), func);
+
+      for (size_t i = 1; i < local_scopes_.size(); ++i) {
+        auto &scope =
+            *local_scopes_[i]->FindVar(kLocalExecScopeName)->Get<Scope *>();
+        auto &p = places_[i];
+        auto *var = scope.FindVar(out_var_handles[i]->name_);
+        auto *dev_ctx = dev_ctxes_[p];
+
+        RunAndRecordEvent(p, [&trg, var, dev_ctx, p] {
+          auto &tensor_gpu = *var->GetMutable<framework::LoDTensor>();
+          auto &tensor_cpu = trg;
+          TensorCopy(tensor_cpu, p, *dev_ctx, &tensor_gpu);
+        });
+      }
+    }
+  }
+}
+
+std::string AllReduceOpHandle::Name() const { return "all_reduce"; }
+}  // namespace details
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/details/fetch_op_handle.cc b/paddle/fluid/framework/details/fetch_op_handle.cc
index 6ae7d764e3382..604fd1db48e76 100644
--- a/paddle/fluid/framework/details/fetch_op_handle.cc
+++ b/paddle/fluid/framework/details/fetch_op_handle.cc
@@ -48,7 +48,9 @@ void FetchOpHandle::WaitAndMergeCPUTensors() const {
 }
 
 void FetchOpHandle::RunImpl() {
+#ifndef PADDLE_WITH_HIP
   WaitInputVarGenerated(platform::CPUPlace());
+#endif
 
   tensors_.resize(inputs_.size());
   platform::CPUPlace cpu;
@@ -65,8 +67,10 @@ void FetchOpHandle::RunImpl() {
     auto &t = var->Get<framework::LoDTensor>();
     if (platform::is_gpu_place(t.place())) {
-#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
+#ifdef PADDLE_WITH_CUDA
       TensorCopySync(t, cpu, &tensors_[i]);
+#elif defined(PADDLE_WITH_HIP)
+      TensorCopy(t, cpu, &tensors_[i]);
 #endif
     } else {
       tensors_[i].ShareDataWith(t);
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index c0851b09b4aaa..186bcc27f2e16 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -476,7 +476,7 @@ bool MultiDevSSAGraphBuilder::IsSparseGradient(const std::string &og) const {
 
 void MultiDevSSAGraphBuilder::SetCommunicationContext(
     OpHandleBase *op_handle, const platform::Place &p) const {
-#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
+#if (defined(PADDLE_WITH_CUDA))
   if (nccl_ctxs_ == nullptr) {
     op_handle->SetDeviceContext(p,
                                 platform::DeviceContextPool::Instance().Get(p));
diff --git a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
index 5bd974d6b789a..5daf5e3e7192c 100644
--- a/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
+++ b/paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.cc
@@ -69,9 +69,11 @@ FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
       drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
     drop_scope_counter_ = 0;
     // Wait All computational streams
+#ifndef PADDLE_WITH_HIP
     for (auto p : places_) {
       platform::DeviceContextPool::Instance().Get(p)->Wait();
     }
+#endif
     for (auto &scope : local_scopes_) {
       auto &local_scope =
           *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc
index 8a9aa6fa9aa71..4a13ec77add61 100644
--- a/paddle/fluid/framework/parallel_executor.cc
+++ b/paddle/fluid/framework/parallel_executor.cc
@@ -26,6 +26,7 @@ limitations under the License. */
 #endif
 #ifdef PADDLE_WITH_HIP
 #include "paddle/fluid/platform/rccl_helper.h"
+#define USE_MEM_COPY 1
 #endif
 
 #include "paddle/fluid/framework/details/multi_devices_graph_check_pass.h"
@@ -246,12 +247,16 @@ void ParallelExecutor::BCastParamsToDevices(
     if (paddle::platform::is_gpu_place(main_tensor.place())) {
 #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP))
       std::vector<void *> buffers;
+#ifndef USE_MEM_COPY
       size_t numel = main_tensor.numel();
+#endif
 #ifdef PADDLE_WITH_CUDA
       ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
 #endif
 #ifdef PADDLE_WITH_HIP
+#ifndef USE_MEM_COPY
       rcclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type());
+#endif
 #endif
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto place = member_->places_[i];
@@ -259,12 +264,19 @@ void ParallelExecutor::BCastParamsToDevices(
         if ((initializing && i == 0) ||
             (!initializing && static_cast<int>(i) == var_dev_id)) {
+#ifndef USE_MEM_COPY
           buffer = const_cast<void *>(main_tensor.data<void>());
+#endif
         } else {
           auto local_scope = member_->local_scopes_[i];
           auto *t = local_scope->Var(var)->GetMutable<LoDTensor>();
           t->Resize(dims);
+#ifdef USE_MEM_COPY
+          t->mutable_data(place, main_tensor.type());
+          framework::TensorCopy(main_tensor, place, t);
+#else
           buffer = t->mutable_data(place, main_tensor.type());
+#endif
         }
         buffers.push_back(buffer);
       }
@@ -289,6 +301,7 @@ void ParallelExecutor::BCastParamsToDevices(
       }
 #endif
 #ifdef PADDLE_WITH_HIP
+#ifndef USE_MEM_COPY
       for (size_t i = 0; i < member_->places_.size(); ++i) {
         auto &nccl_ctx = member_->nccl_ctxs_->at(member_->places_[i]);
         if (initializing) {
@@ -302,6 +315,7 @@ void ParallelExecutor::BCastParamsToDevices(
           }
         }
       }
+#endif
 #endif
       member_->nccl_ctxs_->WaitAll();
     }
diff --git a/paddle/fluid/operators/adam_op.h b/paddle/fluid/operators/adam_op.h
index 84a584f424823..4a989886e5aa6 100644
--- a/paddle/fluid/operators/adam_op.h
+++ b/paddle/fluid/operators/adam_op.h
@@ -295,14 +295,14 @@ class AdamOpKernel : public framework::OpKernel<T> {
     int64_t* rows = nullptr;
 // When compiled without CUDA, the CUDAMutableData() interface should not be
 // provided.
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     if (platform::is_gpu_place(ctx.GetPlace())) {
       rows = grad_merge.mutable_rows()->CUDAMutableData(ctx.GetPlace());
     } else {
 #endif
       rows = grad_merge.mutable_rows()->data();
 
-#if defined(PADDLE_WITH_CUDA)
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     }
 #endif
     auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index d9c1b44d3e28f..2471428680f8f 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -21,6 +21,8 @@ limitations under the License. */
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/platform/float16.h"
 
+#define THRUST_RANDOM_LIMIT_SIZE 4000000
+
 namespace paddle {
 namespace operators {
 
@@ -56,6 +58,27 @@ __global__ void RandomGenerator(const size_t n, const int seed,
   }
 }
 
+template <typename T>
+__global__ void RandomGenerator(const size_t n, const int seed,
+                                const float dropout_prob, const T* src,
+                                T* mask_data, T* dst, float* random_data) {
+  int idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+  T mask;
+  T dest;
+  for (; idx < n; idx += blockDim.x * gridDim.x) {
+    T s = src[idx];
+    if (random_data[idx] < dropout_prob) {
+      mask = static_cast<T>(0);
+    } else {
+      mask = static_cast<T>(1);
+    }
+    dest = s * mask;
+    mask_data[idx] = mask;
+    dst[idx] = dest;
+  }
+}
+
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
@@ -82,8 +105,24 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       int threads = 512;
       int grid = (x->numel() + threads - 1) / threads;
-      hipLaunchKernelGGL((RandomGenerator), dim3(grid), dim3(threads), 0, context.cuda_device_context().stream(),
-          size, seed, dropout_prob, x_data, mask_data, y_data);
+
+#if defined(PADDLE_WITH_HIP)
+      if (size > THRUST_RANDOM_LIMIT_SIZE) {
+        // large size, generate random buffer one-shot to improve performance
+        framework::Tensor random;
+        auto* random_data = random.mutable_data<float>(mask->dims(), context.GetPlace());
+        hiprandGenerator_t generator = context.cuda_device_context().rand_generator();
+        PADDLE_ENFORCE(platform::dynload::hiprandSetPseudoRandomGeneratorSeed(generator, seed));
+        PADDLE_ENFORCE(platform::dynload::hiprandGenerateUniform(generator, random_data, size));
+        hipLaunchKernelGGL((RandomGenerator), dim3(grid), dim3(threads), 0, context.cuda_device_context().stream(),
+            size, seed, dropout_prob, x_data, mask_data, y_data, random_data);
+      } else {
+#endif
+        hipLaunchKernelGGL((RandomGenerator), dim3(grid), dim3(threads), 0, context.cuda_device_context().stream(),
+            size, seed, dropout_prob, x_data, mask_data, y_data);
+#if defined(PADDLE_WITH_HIP)
+      }
+#endif
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
       auto Y = EigenMatrix<T>::Reshape(*y, 1);
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 0628b4b826d27..7f03e0e75ba25 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/dynload/hiprand.h"
 
 namespace paddle {
 namespace operators {
diff --git a/paddle/fluid/platform/device_context.cc b/paddle/fluid/platform/device_context.cc
index 00e7ee17aea27..ffe92722f0222 100644
--- a/paddle/fluid/platform/device_context.cc
+++ b/paddle/fluid/platform/device_context.cc
@@ -288,6 +288,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
   } else {
     miopen_handle_ = nullptr;
   }
+  hiprand_generator = nullptr;
 }
 
 CUDADeviceContext::~CUDADeviceContext() {
@@ -300,6 +301,9 @@ CUDADeviceContext::~CUDADeviceContext() {
   eigen_stream_.reset();
   eigen_device_.reset();
   PADDLE_ENFORCE(hipStreamDestroy(stream_));
+  if (hiprand_generator != nullptr) {
+    PADDLE_ENFORCE(dynload::hiprandDestroyGenerator(hiprand_generator));
+  }
 }
 
 Place CUDADeviceContext::GetPlace() const { return place_; }
@@ -330,6 +334,13 @@ miopenHandle_t CUDADeviceContext::miopen_handle() const { return miopen_handle_;
 
 hipStream_t CUDADeviceContext::stream() const { return stream_; }
 
+hiprandGenerator_t CUDADeviceContext::rand_generator() const {
+  if (hiprand_generator == nullptr) {
+    PADDLE_ENFORCE(dynload::hiprandCreateGenerator((hiprandGenerator_t *)&hiprand_generator, HIPRAND_RNG_PSEUDO_DEFAULT));
+  }
+  return hiprand_generator;
+}
+
 CUDAPinnedDeviceContext::CUDAPinnedDeviceContext() {
   eigen_device_.reset(new Eigen::DefaultDevice());
 }
diff --git a/paddle/fluid/platform/device_context.h b/paddle/fluid/platform/device_context.h
index 7c425d4004dca..b4d5ce4d19f3b 100644
--- a/paddle/fluid/platform/device_context.h
+++ b/paddle/fluid/platform/device_context.h
@@ -26,6 +26,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_HIP
 #include "paddle/fluid/platform/dynload/hipblas.h"
 #include "paddle/fluid/platform/dynload/miopen.h"
+#include "paddle/fluid/platform/dynload/hiprand.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #define EIGEN_USE_GPU
 #endif
@@ -185,6 +186,8 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return cuda stream in the device context. */
   hipStream_t stream() const;
 
+  hiprandGenerator_t rand_generator() const;
+
   template <typename Callback>
   void RecordEvent(hipEvent_t ev, Callback callback) {
     std::lock_guard<std::mutex> guard(mutex_);
@@ -202,6 +205,7 @@ class CUDADeviceContext : public DeviceContext {
   hipStream_t stream_;
   miopenHandle_t miopen_handle_;
   hipblasHandle_t hipblas_handle_;
+  hiprandGenerator_t hiprand_generator;
   int compute_capability;
   int multi_process;
   int max_threads_per_mp;
diff --git a/paddle/fluid/platform/gpu_info_hip.cc b/paddle/fluid/platform/gpu_info_hip.cc
index 0422b17f17168..a1a9a54e6f98b 100644
--- a/paddle/fluid/platform/gpu_info_hip.cc
+++ b/paddle/fluid/platform/gpu_info_hip.cc
@@ -134,16 +134,14 @@ void GpuMemcpySync(void *dst, const void *src, size_t count,
 
 void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                         int src_device, size_t count, hipStream_t stream) {
-  PADDLE_ENFORCE(
-      hipMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream),
-      "hipMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync");
+  PADDLE_ENFORCE(hipMemcpyAsync(dst, src, count, hipMemcpyDeviceToDevice, stream),
+                 "hipMemcpyAsync failed in paddle::platform::GpuMemcpyPeerAsync");
 }
 
 void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count) {
-  PADDLE_ENFORCE(
-      hipMemcpyPeer(dst, dst_device, src, src_device, count),
-      "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync");
+  PADDLE_ENFORCE(hipMemcpy(dst, src, count, hipMemcpyDeviceToDevice),
+                 "hipMemcpy failed in paddle::platform::GpuMemcpyPeerSync");
 }
 
 void GpuMemsetAsync(void *dst, int value, size_t count, hipStream_t stream) {