From cced930b61ba246dffec68bbe09bd9e22a142d64 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Tue, 23 Feb 2021 13:49:54 +0800 Subject: [PATCH] [ROCM] update fluid operators for rocm (part1), test=develop (#31077) --- .../controlflow/conditional_block_op.h | 2 +- .../operators/controlflow/get_places_op.cc | 4 +- .../operators/controlflow/while_op_helper.cc | 2 +- .../fluid/operators/detection/CMakeLists.txt | 8 ++-- .../fluid/operators/detection/bbox_util.cu.h | 21 +++++++++- .../detection/collect_fpn_proposals_op.cu | 40 ++++++++++++++++--- .../detection/distribute_fpn_proposals_op.cu | 32 +++++++++++++-- .../detection/sigmoid_focal_loss_op.cu | 1 - .../operators/detection/target_assign_op.h | 4 +- .../operators/distributed/CMakeLists.txt | 2 +- .../distributed/brpc/brpc_sendrecvop_utils.cc | 7 +++- .../distributed/brpc/brpc_serde_test.cc | 4 +- .../operators/distributed/grpc/grpc_serde.cc | 7 +++- .../distributed/grpc/grpc_serde_test.cc | 4 +- .../distributed/parameter_prefetch.cc | 2 +- .../operators/distributed/sendrecvop_utils.cc | 2 +- .../distributed/variable_response.cc | 6 +-- paddle/fluid/operators/metrics/accuracy_op.cu | 13 +++++- paddle/fluid/operators/metrics/auc_op.cu | 17 ++++++++ 19 files changed, 142 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/controlflow/conditional_block_op.h b/paddle/fluid/operators/controlflow/conditional_block_op.h index c8ab2c91e9122..b9ea2ade6cb90 100644 --- a/paddle/fluid/operators/controlflow/conditional_block_op.h +++ b/paddle/fluid/operators/controlflow/conditional_block_op.h @@ -73,7 +73,7 @@ class ConditionalOp : public framework::OperatorBase { ips[0]->numel())); bool res = false; if (platform::is_gpu_place(ips[0]->place())) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) framework::LoDTensor cpu_tensor; framework::TensorCopy(*ips[0], platform::CPUPlace(), &cpu_tensor); platform::DeviceContextPool::Instance().Get(ips[0]->place())->Wait(); diff --git 
a/paddle/fluid/operators/controlflow/get_places_op.cc b/paddle/fluid/operators/controlflow/get_places_op.cc index 2bab8e57916ef..dec0e789776a4 100644 --- a/paddle/fluid/operators/controlflow/get_places_op.cc +++ b/paddle/fluid/operators/controlflow/get_places_op.cc @@ -26,7 +26,7 @@ namespace imperative { class OpBase; } // namespace imperative } // namespace paddle -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/fluid/platform/gpu_info.h" #endif @@ -34,7 +34,7 @@ namespace paddle { namespace operators { static size_t CUDADevCount() { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) return platform::GetCUDADeviceCount(); #else return 0UL; diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index 904cc214edd09..c9d4e1510985f 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -223,7 +223,7 @@ bool GetCondData(const framework::LoDTensor &cond) { } // when platform::is_gpu_place(cond.place()) is true std::unique_ptr cpu_cond{new framework::LoDTensor()}; -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get()); #else PADDLE_THROW(platform::errors::PreconditionNotMet( diff --git a/paddle/fluid/operators/detection/CMakeLists.txt b/paddle/fluid/operators/detection/CMakeLists.txt index 1915323f3c324..efbd653ffd3b0 100644 --- a/paddle/fluid/operators/detection/CMakeLists.txt +++ b/paddle/fluid/operators/detection/CMakeLists.txt @@ -40,10 +40,12 @@ detection_library(box_decoder_and_assign_op SRCS box_decoder_and_assign_op.cc bo detection_library(sigmoid_focal_loss_op SRCS sigmoid_focal_loss_op.cc sigmoid_focal_loss_op.cu) detection_library(retinanet_detection_output_op SRCS retinanet_detection_output_op.cc) -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) 
set(TMPDEPS memory) - if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) - set(TMPDEPS memory cub) + if(WITH_GPU) + if (${CMAKE_CUDA_COMPILER_VERSION} LESS 11.0) + set(TMPDEPS memory cub) + endif() endif() detection_library(generate_proposals_op SRCS generate_proposals_op.cc generate_proposals_op.cu DEPS ${TMPDEPS}) detection_library(generate_proposals_v2_op SRCS generate_proposals_v2_op.cc generate_proposals_v2_op.cu DEPS ${TMPDEPS}) diff --git a/paddle/fluid/operators/detection/bbox_util.cu.h b/paddle/fluid/operators/detection/bbox_util.cu.h index 0247093d03a91..0d52fd4161382 100644 --- a/paddle/fluid/operators/detection/bbox_util.cu.h +++ b/paddle/fluid/operators/detection/bbox_util.cu.h @@ -16,10 +16,16 @@ limitations under the License. */ #include #include #include +#ifdef __NVCC__ #include "cub/cub.cuh" +#include "paddle/fluid/platform/cudnn_helper.h" +#endif +#ifdef __HIPCC__ +#include +#include "paddle/fluid/platform/miopen_helper.h" +#endif #include "paddle/fluid/operators/gather.cu.h" #include "paddle/fluid/operators/math/math_function.h" -#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/for_range.h" namespace paddle { @@ -58,16 +64,27 @@ static void SortDescending(const platform::CUDADeviceContext &ctx, // Determine temporary device storage requirements size_t temp_storage_bytes = 0; +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num); +#else cub::DeviceRadixSort::SortPairsDescending( nullptr, temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num); +#endif // Allocate temporary storage auto place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - // Run sorting operation +// Run sorting operation +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairsDescending( + d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in, + idx_out, num); +#else 
cub::DeviceRadixSort::SortPairsDescending( d_temp_storage->ptr(), temp_storage_bytes, keys_in, keys_out, idx_in, idx_out, num); +#endif } template diff --git a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu index 86207052bb2be..4bb0f9ca67fb2 100644 --- a/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/collect_fpn_proposals_op.cu @@ -9,8 +9,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +#endif + +#include #include "paddle/fluid/framework/mixed_vector.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/memory/memcpy.h" @@ -135,17 +141,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel { // Determine temporary device storage requirements size_t temp_storage_bytes = 0; +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairsDescending( + nullptr, temp_storage_bytes, concat_scores.data(), keys_out, idx_in, + idx_out, total_roi_num); +#else cub::DeviceRadixSort::SortPairsDescending( nullptr, temp_storage_bytes, concat_scores.data(), keys_out, idx_in, idx_out, total_roi_num); +#endif // Allocate temporary storage auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - // Run sorting operation - // sort score to get corresponding index +// Run sorting operation +// sort score to get corresponding index +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairsDescending( + d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data(), + keys_out, idx_in, idx_out, total_roi_num); +#else cub::DeviceRadixSort::SortPairsDescending( d_temp_storage->ptr(), temp_storage_bytes, concat_scores.data(), keys_out, idx_in, idx_out, total_roi_num); +#endif 
index_out_t.Resize({real_post_num}); Tensor sorted_rois; sorted_rois.mutable_data({real_post_num, kBBoxSize}, dev_ctx.GetPlace()); @@ -167,17 +185,29 @@ class GPUCollectFpnProposalsOpKernel : public framework::OpKernel { out_id_t.mutable_data({real_post_num}, dev_ctx.GetPlace()); // Determine temporary device storage requirements temp_storage_bytes = 0; +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairs( + nullptr, temp_storage_bytes, sorted_batch_id.data(), out_id_data, + batch_idx_in, index_out_t.data(), real_post_num); +#else cub::DeviceRadixSort::SortPairs( nullptr, temp_storage_bytes, sorted_batch_id.data(), out_id_data, batch_idx_in, index_out_t.data(), real_post_num); +#endif // Allocate temporary storage d_temp_storage = memory::Alloc(place, temp_storage_bytes); - // Run sorting operation - // sort batch_id to get corresponding index +// Run sorting operation +// sort batch_id to get corresponding index +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data(), + out_id_data, batch_idx_in, index_out_t.data(), real_post_num); +#else cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, sorted_batch_id.data(), out_id_data, batch_idx_in, index_out_t.data(), real_post_num); +#endif GPUGather(dev_ctx, sorted_rois, index_out_t, fpn_rois); diff --git a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu index 7550ff91fd542..63f205947d9b5 100644 --- a/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu +++ b/paddle/fluid/operators/detection/distribute_fpn_proposals_op.cu @@ -12,8 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
*/ -#include +#ifdef __NVCC__ #include "cub/cub.cuh" +#endif +#ifdef __HIPCC__ +#include +#endif + +#include #include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/operators/detection/bbox_util.h" #include "paddle/fluid/operators/detection/distribute_fpn_proposals_op.h" @@ -143,24 +149,42 @@ class GPUDistributeFpnProposalsOpKernel : public framework::OpKernel { // Determine temporary device storage requirements size_t temp_storage_bytes = 0; +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, + target_lvls_data, keys_out, + idx_in, idx_out, roi_num); +#else cub::DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, target_lvls_data, keys_out, idx_in, idx_out, roi_num); +#endif // Allocate temporary storage auto d_temp_storage = memory::Alloc(place, temp_storage_bytes); - // Run sorting operation - // sort target level to get corresponding index +// Run sorting operation +// sort target level to get corresponding index +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out, + idx_in, idx_out, roi_num); +#else cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, target_lvls_data, keys_out, idx_in, idx_out, roi_num); +#endif int* restore_idx_data = restore_index->mutable_data({roi_num, 1}, dev_ctx.GetPlace()); - // sort current index to get restore index +// sort current index to get restore index +#ifdef PADDLE_WITH_HIP + hipcub::DeviceRadixSort::SortPairs( + d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in, + restore_idx_data, roi_num); +#else cub::DeviceRadixSort::SortPairs( d_temp_storage->ptr(), temp_storage_bytes, idx_out, keys_out, idx_in, restore_idx_data, roi_num); +#endif int start = 0; auto multi_rois_num = ctx.MultiOutput("MultiLevelRoIsNum"); diff --git a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu index 
f12d60c8b0fc0..ed1676200dc47 100644 --- a/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu +++ b/paddle/fluid/operators/detection/sigmoid_focal_loss_op.cu @@ -11,7 +11,6 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "cub/cub.cuh" #include "paddle/fluid/operators/detection/sigmoid_focal_loss_op.h" #include "paddle/fluid/operators/math.h" #include "paddle/fluid/platform/cuda_primitives.h" diff --git a/paddle/fluid/operators/detection/target_assign_op.h b/paddle/fluid/operators/detection/target_assign_op.h index da85e4c5e444c..01b15865e93b6 100644 --- a/paddle/fluid/operators/detection/target_assign_op.h +++ b/paddle/fluid/operators/detection/target_assign_op.h @@ -107,7 +107,7 @@ class TargetAssignKernel : public framework::OpKernel { int64_t k = x->dims()[2]; auto x_lod = x->lod().back(); -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) size_t* x_lod_data = x_lod.MutableData(ctx.GetPlace()); #else size_t* x_lod_data = x_lod.data(); @@ -129,7 +129,7 @@ class TargetAssignKernel : public framework::OpKernel { "TargetAssignOp input(NegIndices) needs 1 level of LoD")); const int* neg_idx_data = neg_indices->data(); auto neg_lod = neg_indices->lod().back(); -#if defined(PADDLE_WITH_CUDA) +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) size_t* neg_lod_data = neg_lod.MutableData(ctx.GetPlace()); #else size_t* neg_lod_data = neg_lod.data(); diff --git a/paddle/fluid/operators/distributed/CMakeLists.txt b/paddle/fluid/operators/distributed/CMakeLists.txt index 1417676426c2b..c9db6148bc45d 100644 --- a/paddle/fluid/operators/distributed/CMakeLists.txt +++ b/paddle/fluid/operators/distributed/CMakeLists.txt @@ -61,7 +61,7 @@ cc_library(parameter_send SRCS parameter_send.cc DEPS sendrecvop_rpc memory) 
cc_library(parameter_recv SRCS parameter_recv.cc DEPS sendrecvop_rpc memory) cc_library(communicator SRCS communicator.cc DEPS scope selected_rows tensor variable_helper selected_rows_functor simple_threadpool parameter_send parameter_recv generator) cc_test(communicator_test SRCS communicator_test.cc DEPS communicator) -if(WITH_GPU) +if(WITH_GPU OR WITH_ROCM) cc_test(collective_server_test SRCS collective_server_test.cc DEPS sendrecvop_rpc executor ${RPC_DEPS} selected_rows_functor scope math_function) diff --git a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc b/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc index d66281ac7c7ae..411c0f36debd3 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_sendrecvop_utils.cc @@ -15,6 +15,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_NCCL #include #endif +#ifdef PADDLE_WITH_RCCL +#include +#endif #include #include #include @@ -144,7 +147,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, } else if (var->IsType()) { request->set_type(::sendrecv::SELECTED_ROWS); payload.reset(new TensorPayload(GetSelectedRowsPayload(var, ctx, request))); -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) } else if (var->IsType()) { request->set_type(::sendrecv::NCCL_ID); const ncclUniqueId& uid = var->Get(); @@ -172,7 +175,7 @@ void SerializeToIOBuf(const std::string& name, framework::Variable* var, static_cast(payload->ptr()), payload->memory_size()); } else { if (platform::is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) IOBufWriter::AppendZeroCopy( name, iobuf, ::sendrecv::VariableMessage::kSerializedFieldNumber, static_cast(payload->ptr()), payload->memory_size(), diff --git a/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc 
b/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc index b902d3db48778..bcf20ad076b11 100644 --- a/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/brpc/brpc_serde_test.cc @@ -159,7 +159,7 @@ void RunTestLodTensor(platform::Place place) { TEST(LodTensor, Run) { platform::CPUPlace place; RunTestLodTensor(place); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) platform::CUDAPlace gpu(0); RunTestLodTensor(gpu); #endif @@ -168,7 +168,7 @@ TEST(LodTensor, Run) { TEST(SelectedRows, Run) { platform::CPUPlace place; RunSerdeTestSelectedRows(place); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) platform::CUDAPlace gpu; RunSerdeTestSelectedRows(gpu); #endif diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc/grpc_serde.cc index 13343ed4a78dd..0fc9b69577914 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_serde.cc @@ -15,6 +15,9 @@ limitations under the License. */ #ifdef PADDLE_WITH_NCCL #include #endif +#ifdef PADDLE_WITH_RCCL +#include +#endif #include #include #include "grpcpp/impl/codegen/byte_buffer.h" @@ -75,7 +78,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, } else if (var->IsType()) { request.set_type(::sendrecv::SELECTED_ROWS); payload = new TensorPayload(GetSelectedRowsPayload(var, ctx, &request)); -#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) } else if (var->IsType()) { request.set_type(::sendrecv::NCCL_ID); #endif @@ -91,7 +94,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, e.WriteRawBytes(std::string(header.data(), header.size())); // NCCLID is copied directly to the message, return bytebuffer // with only one slice if serializing NCCLID. 
-#ifdef PADDLE_WITH_NCCL +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) if (var->IsType()) { e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, NCCL_UNIQUE_ID_BYTES); diff --git a/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc b/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc index 749c1bf39a486..d407a72938a74 100644 --- a/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc +++ b/paddle/fluid/operators/distributed/grpc/grpc_serde_test.cc @@ -206,7 +206,7 @@ TEST(LodTensor, Run) { platform::CPUPlace place; RunTestLodTensor(place); RunTestLodTensor(place, 1); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) platform::CUDAPlace gpu(0); RunTestLodTensor(gpu); RunTestLodTensor(gpu, 1); @@ -217,7 +217,7 @@ TEST(SelectedRows, Run) { platform::CPUPlace place; RunSerdeTestSelectedRows(place); -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) platform::CUDAPlace gpu; RunSerdeTestSelectedRows(gpu); #endif diff --git a/paddle/fluid/operators/distributed/parameter_prefetch.cc b/paddle/fluid/operators/distributed/parameter_prefetch.cc index df47422fc059f..558d70e5c3353 100644 --- a/paddle/fluid/operators/distributed/parameter_prefetch.cc +++ b/paddle/fluid/operators/distributed/parameter_prefetch.cc @@ -281,7 +281,7 @@ void prefetchs(const std::vector &id_var_names, } } } else { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::vector ids_value_vec(ids_size * vec_dim_1); for (auto idx = 0; idx < static_cast(ids_size); idx++) { const auto &id = ids[idx]; diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc index 39b4b3daf8c8c..107c74eb2670e 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.cc @@ -39,7 +39,7 @@ using VarMsg = sendrecv::VariableMessage; static TensorPayload 
GetCommunicationAllocationFromTensor( const platform::DeviceContext& ctx, const framework::Tensor& tensor) { if (is_gpu_place(ctx.GetPlace())) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) PADDLE_ENFORCE_EQ( is_gpu_place(tensor.place()), true, platform::errors::PreconditionNotMet("Please run in gpu place.")); diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 4c161f044d8d7..79b0843968e85 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -33,7 +33,7 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input, int total_written = 0; if (platform::is_gpu_place(place)) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto& gpu_dev_ctx = static_cast(dev_ctx); platform::CPUPlace cpu; @@ -62,7 +62,7 @@ bool VariableResponse::ReadRaw(::google::protobuf::io::CodedInputStream* input, gpu_dev_ctx.Wait(); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "Unexpected branch, please compile with PADDLE_WITH_CUDA")); + "Unexpected branch, please compile with WITH_GPU or WITH_ROCM")); #endif return true; } else if (platform::is_xpu_place(place)) { @@ -221,7 +221,7 @@ bool VariableResponse::ProcSerializedField( platform::errors::PreconditionNotMet("meta info should be got first!")); if (meta_.type() == sendrecv::NCCL_ID) { -#ifdef PADDLE_WITH_CUDA +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) auto* var = scope_->FindVar(meta_.varname()); if (var != nullptr) { ncclUniqueId* id = var->GetMutable(); diff --git a/paddle/fluid/operators/metrics/accuracy_op.cu b/paddle/fluid/operators/metrics/accuracy_op.cu index ab5ee745aaf8b..3d22fc60993c7 100644 --- a/paddle/fluid/operators/metrics/accuracy_op.cu +++ b/paddle/fluid/operators/metrics/accuracy_op.cu @@ -43,8 +43,19 @@ __global__ void 
AccuracyCudaKernel(const int N, const int D, total[threadIdx.x] = count; __syncthreads(); - // reduce the count with init value 0, and output accuracy. +// reduce the count with init value 0, and output accuracy. +#ifdef PADDLE_WITH_CUDA int result = thrust::reduce(thrust::device, total, total + BlockSize, 0); +#else + // HIP: thrust::reduce is not supported in __device__ code, so sum total[] with a halving loop (covers every element only when BlockSize is a power of two) + for (int s = BlockSize / 2; s > 0; s >>= 1) { + if (threadIdx.x < s) { + total[threadIdx.x] += total[threadIdx.x + s]; + } + __syncthreads(); + } + int result = total[0]; +#endif if (threadIdx.x == 0) { *correct_data = result; *accuracy = static_cast(result) / static_cast(N); diff --git a/paddle/fluid/operators/metrics/auc_op.cu b/paddle/fluid/operators/metrics/auc_op.cu index 13da4ff0857d9..40609381c17ae 100644 --- a/paddle/fluid/operators/metrics/auc_op.cu +++ b/paddle/fluid/operators/metrics/auc_op.cu @@ -130,6 +130,7 @@ class AucCUDAKernel : public framework::OpKernel { auto *pos_in_data = stat_pos_in_tensor->data(); auto *stat_neg_in_tensor = ctx.Input("StatNeg"); auto *neg_in_data = stat_neg_in_tensor->data(); +#ifdef PADDLE_WITH_CUDA if (stat_pos_in_tensor != stat_pos) { cudaMemcpy(origin_stat_pos, pos_in_data, ((1 + slide_steps) * (num_thresholds + 1) + @@ -144,6 +145,22 @@ class AucCUDAKernel : public framework::OpKernel { sizeof(int64_t), cudaMemcpyDeviceToDevice); } +#else + if (stat_pos_in_tensor != stat_pos) { + hipMemcpy(origin_stat_pos, pos_in_data, + ((1 + slide_steps) * (num_thresholds + 1) + + (slide_steps > 0 ? 1 : 0)) * + sizeof(int64_t), + hipMemcpyDeviceToDevice); + } + if (stat_neg_in_tensor != stat_neg) { + hipMemcpy(origin_stat_neg, neg_in_data, + ((1 + slide_steps) * (num_thresholds + 1) + + (slide_steps > 0 ? 1 : 0)) * + sizeof(int64_t), + hipMemcpyDeviceToDevice); + } +#endif statAuc(ctx, label, predict, num_thresholds, slide_steps, origin_stat_pos, origin_stat_neg);