diff --git a/cmake/inference_lib.cmake b/cmake/inference_lib.cmake index d0a055d0f2e64..9f1268ce36c41 100755 --- a/cmake/inference_lib.cmake +++ b/cmake/inference_lib.cmake @@ -335,6 +335,12 @@ copy( DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/ ) +copy( + inference_lib_dist + SRCS ${PADDLE_SOURCE_DIR}/paddle/phi/core/distributed/auto_parallel/*.h + DSTS ${PADDLE_INFERENCE_INSTALL_DIR}/paddle/include/paddle/phi/core/distributed/auto_parallel/ +) + copy( inference_lib_dist SRCS ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/init_phi.h diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 8894a06267b51..b28357672c046 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -453,14 +453,176 @@ paddle::Tensor BuildEmptyDistPaddleTensor( #endif #ifdef PADDLE_WITH_DISTRIBUTE -std::tuple PrepareCtxForAutoParallel( + +phi::distributed::SpmdInfo RunInferSpmdFn( + const paddle::OpMetaInfo& op_info, + const std::vector& inputs, + const std::vector& outputs, + paddle::CustomOpKernelContext& ctx) { // NOLINT + auto& infer_spmd_func = paddle::OpMetaInfoHelper::GetInferSpmdFn(op_info); + if (infer_spmd_func == nullptr) { + // default rule + std::vector meta_dist_inputs; + auto all_inputs = ctx.AllMutableInput(); + for (auto& t : *all_inputs) { + phi::distributed::DistMetaTensor meta_dist_input; + if (t.impl().get()) { + meta_dist_input = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + meta_dist_inputs.push_back(meta_dist_input); + } + auto spmd_info_tmp = + phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_inputs); + // flatten input + phi::distributed::SpmdInfo spmd_info; + auto dist_attrs = PADDLE_GET(std::vector, + spmd_info_tmp.first[0]); + for (auto& e : dist_attrs) { + spmd_info.first.push_back(std::move(e)); + } + return spmd_info; + } + std::vector tensor_inputs; + size_t input_size = inputs.size(); + for (size_t i = 0; i < input_size; ++i) { + const auto& in_name = inputs[i]; + if (paddle::framework::detail::IsDuplicableVar(in_name)) { + std::vector meta_tensors; + auto& range = ctx.InputRangeAt(i); + for (size_t j = range.first; j < range.second; ++j) { + auto& t = ctx.InputAt(j); + phi::distributed::DistMetaTensor meta_tensor; + if (t.impl().get()) { + meta_tensor = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + meta_tensors.emplace_back(std::move(meta_tensor)); + } + tensor_inputs.emplace_back(std::move(meta_tensors)); + } else { + auto& range = ctx.InputRangeAt(i); + auto& t = ctx.InputAt(range.first); + phi::distributed::DistMetaTensor meta_tensor; + if (t.impl().get()) { + meta_tensor = + paddle::experimental::MakeDistMetaTensor(*(t.impl().get())); + } + tensor_inputs.emplace_back(std::move(meta_tensor)); + } + } + const std::vector& attrs = ctx.Attrs(); + auto spmd_info_tmp = infer_spmd_func(tensor_inputs, attrs); + // flatten input + phi::distributed::SpmdInfo spmd_info; + for (auto& e : spmd_info_tmp.first) { + if (paddle::holds_alternative(e)) { + spmd_info.first.push_back( + std::move(PADDLE_GET(phi::distributed::TensorDistAttr, e))); + } else { + for (auto& ee : + PADDLE_GET(std::vector, e)) { + spmd_info.first.push_back(std::move(ee)); + } + } + } + + // flatten output + for (auto& e : spmd_info_tmp.second) { + if (paddle::holds_alternative(e)) { + spmd_info.second.push_back( + 
std::move(PADDLE_GET(phi::distributed::TensorDistAttr, e))); + } else { + for (auto& ee : + PADDLE_GET(std::vector, e)) { + spmd_info.second.push_back(std::move(ee)); + } + } + } + + return spmd_info; +} + +std::vector> RunInferShapeFn( const paddle::OpMetaInfo& op_info, bool is_forward, bool is_double_grad, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, paddle::CustomOpKernelContext& ctx) { // NOLINT + auto& infer_shape_func = paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = + RunInferShapeFunc(ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = + RunDefaultGradInferShapeFunc(ctx, inputs, outputs, is_double_grad); + } + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + return out_dims; +} + +std::vector> RunInferDtypeFn( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, + paddle::CustomOpKernelContext& ctx) { // NOLINT + + auto& infer_dtype_func = paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = + RunInferDtypeFunc(ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dtypes = RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = + RunDefaultGradInferDtypeFunc(ctx, inputs, outputs, is_double_grad); + } + } + PADDLE_ENFORCE_EQ( + out_dtypes.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); + return out_dtypes; +} + +std:: + tuple + PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx, // NOLINT + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; + phi::distributed::SpmdInfo spmd_info; const auto& inputs = paddle::OpMetaInfoHelper::GetInputs(op_info); const auto& outputs = paddle::OpMetaInfoHelper::GetOutputs(op_info); @@ -483,115 +645,73 @@ std::tuple PrepareCtxForAutoParallel( .process_mesh(); rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh); - std::vector input_x(x.size()); - for (size_t i = 0; i < input_x.size(); ++i) { - input_x[i] = x.at(i).impl().get(); - } + spmd_info = RunInferSpmdFn(op_info, inputs, outputs, ctx); - auto meta_dist_input_x = paddle::experimental::MakeDistMetaTensor(input_x); - auto spmd_info = - phi::distributed::VariadicReplicatedInferSpmdDynamic(meta_dist_input_x); current_process_mesh = paddle::holds_alternative( spmd_info.first[0]) ? 
paddle::get<0>(spmd_info.first[0]).process_mesh() : paddle::get<1>(spmd_info.first[0]).at(0).process_mesh(); + std::vector> out_dims = RunInferShapeFn( + op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx); + + std::vector> out_dtypes = RunInferDtypeFn( + op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx); + if (rank_is_in_current_mesh) { auto* dev_ctx = phi::DeviceContextPool::Instance().Get(x.at(0).place()); - auto dist_input_x = paddle::experimental::ReshardApiInputToKernelInput( - dev_ctx, x, spmd_info.first[0]); for (size_t i = 0; i < x.size(); ++i) { + auto dist_input_i = paddle::experimental::ReshardApiInputToKernelInput( + dev_ctx, x[i], spmd_info.first[i]); all_inputs->at(i).set_impl( - std::make_shared(dist_input_x[i]->value())); - } - } else { - auto& infer_shape_func = - paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); - auto& infer_dtype_func = - paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); - - std::vector> out_dims; - if (infer_shape_func) { - out_dims = RunInferShapeFunc( - ctx, infer_shape_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dims = - RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dims = RunDefaultGradInferShapeFunc( - ctx, inputs, outputs, is_double_grad); - } - } - - std::vector> out_dtypes; - if (infer_dtype_func) { - out_dtypes = RunInferDtypeFunc( - ctx, infer_dtype_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dtypes = - RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dtypes = RunDefaultGradInferDtypeFunc( - ctx, inputs, outputs, is_double_grad); - } + std::make_shared(dist_input_i->value())); + dist_inputs.emplace_back(dist_input_i); } + } + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); PADDLE_ENFORCE_EQ( - out_dims.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_shape return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dims.size())); - + out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); PADDLE_ENFORCE_EQ( - out_dtypes.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_dtype return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dtypes.size())); - - for (size_t i = 0; i < out_dims.size(); ++i) { - const auto& out_dim = out_dims.at(i); - const auto& out_dtype = out_dtypes.at(i); - const auto& pair = ctx.OutputRangeAt(i); - PADDLE_ENFORCE_EQ( - out_dim.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dim.size())); - PADDLE_ENFORCE_EQ( - out_dtype.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dtype.size())); - - if (out_dim.size() == 1) { + out_dtype.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + output_dims.emplace_back(out_dim[0]); + if 
(!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]); - } else { - for (size_t j = pair.first; j < pair.second; j++) { + } + } else { + for (size_t j = pair.first; j < pair.second; j++) { + output_dims.emplace_back(out_dim[j]); + if (!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[j], out_dtype[j]); } } } - return std::tuple( - run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); } } - return std::tuple( - run_auto_parallel, rank_is_in_current_mesh, current_process_mesh); + return {run_auto_parallel, + rank_is_in_current_mesh, + current_process_mesh, + spmd_info}; } #endif @@ -599,27 +719,48 @@ std::tuple PrepareCtxForAutoParallel( void TransCtxTensorsToDistTensors( paddle::CustomOpKernelContext& ctx, // NOLINT bool run_auto_parallel, - const phi::distributed::ProcessMesh& current_process_mesh) { + const phi::distributed::ProcessMesh& current_process_mesh, + const phi::distributed::SpmdInfo& spmd_info, + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT if (run_auto_parallel) { std::vector* output_all = ctx.AllMutableOutput(); for (size_t i = 0; i < output_all->size(); ++i) { auto& tensor = output_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(common::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); + phi::distributed::TensorDistAttr dist_attr; + if (!spmd_info.second.empty()) { + dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, + spmd_info.second[i]); + } else { + std::vector shape = common::vectorize(output_dims[i]); + dist_attr.set_default_dims_mapping(shape); + dist_attr.set_process_mesh(current_process_mesh); + } auto dist_t = std::make_shared( std::dynamic_pointer_cast(tensor.impl()), + output_dims[i], dist_attr); tensor.set_impl(dist_t); } std::vector* input_all = ctx.AllMutableInput(); for (size_t i = 0; i < input_all->size(); ++i) { auto& tensor = input_all->at(i); - phi::distributed::TensorDistAttr dist_attr = - phi::distributed::TensorDistAttr(common::vectorize(tensor.dims())); - dist_attr.set_process_mesh(current_process_mesh); + phi::distributed::TensorDistAttr dist_attr; + phi::DDim global_dims; + + if (i < dist_inputs.size()) { + auto& dist_input = dist_inputs.at(i); + global_dims = dist_input->dims(); + dist_attr = dist_input->dist_attr(); + } else { + dist_attr = PADDLE_GET_CONST(phi::distributed::TensorDistAttr, + spmd_info.first[i]); + global_dims = tensor.dims(); + } auto dist_t = std::make_shared( std::dynamic_pointer_cast(tensor.impl()), + global_dims, dist_attr); tensor.set_impl(dist_t); } @@ -637,11 +778,15 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); #ifdef PADDLE_WITH_DISTRIBUTE - auto result = - PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); + // for output + std::vector> dist_inputs; + std::vector output_dims; + auto result = PrepareCtxForAutoParallel( + op_info, is_forward, is_double_grad, ctx, dist_inputs, output_dims); bool run_auto_parallel = std::get<0>(result); bool rank_is_in_current_mesh = std::get<1>(result); phi::distributed::ProcessMesh current_process_mesh = std::get<2>(result); + auto& spmd_info = std::get<3>(result); if (!rank_is_in_current_mesh) { return; } @@ -667,7 +812,12 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.AssignInplaceOutputs(); #ifdef 
PADDLE_WITH_DISTRIBUTE - TransCtxTensorsToDistTensors(ctx, run_auto_parallel, current_process_mesh); + TransCtxTensorsToDistTensors(ctx, + run_auto_parallel, + current_process_mesh, + spmd_info, + dist_inputs, + output_dims); #endif } diff --git a/paddle/phi/core/distributed/auto_parallel/device_mesh.h b/paddle/phi/core/distributed/auto_parallel/device_mesh.h index 8cfdc6ed242f0..0741e03fe94c0 100644 --- a/paddle/phi/core/distributed/auto_parallel/device_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/device_mesh.h @@ -23,7 +23,6 @@ limitations under the License. */ #include #include -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/paddle/phi/core/distributed/auto_parallel/dist_attr.h b/paddle/phi/core/distributed/auto_parallel/dist_attr.h index d158fc848c8d4..e4016b9f65cdc 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_attr.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_attr.h @@ -18,6 +18,7 @@ limitations under the License. */ #include #include #include +#include #include #include diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index c41effe6c8522..885797b7386e4 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -162,6 +162,21 @@ DistTensor::DistTensor(const std::shared_ptr& local_value, } } +DistTensor::DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr) + : global_dims_(global_dims), dist_attr_(dist_attr) { + process_mesh_ = dist_attr_.process_mesh(); + placements_ = ToPlacements(dist_attr); + if (IsCurRankInMesh(process_mesh_)) { + value_ = local_value; + } else { + value_ = std::make_shared( + std::make_shared(nullptr, 0, local_value->place()), + phi::DenseTensorMeta(local_value->dtype(), global_dims_)); + } +} + DistTensor::DistTensor(const std::shared_ptr& global_value, const ProcessMesh& process_mesh, const Placements& placements) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index 55b35ffe5c25a..5ad10c76b2508 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -58,6 +58,13 @@ class DistTensor final const ProcessMesh& process_mesh, const Placements& placements); + /// \brief Construct a dist tensor based local dense tensor. + /// \param global_dims The global dim of the dist tensor. + /// \param dist_attr The distributed attributes of the current tensor. + DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr); + /// \brief Construct a dist tensor based local dense tensor. /// \param global_dims The global dim of the dist tensor. /// \param process_mesh The process mesh of the current tensor. diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc index a1b60e27c27e6..983725880f352 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.cc +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.cc @@ -16,6 +16,8 @@ limitations under the License. 
*/ #include #include + +#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" namespace phi { diff --git a/paddle/phi/core/distributed/auto_parallel/process_mesh.h b/paddle/phi/core/distributed/auto_parallel/process_mesh.h index 792d5e38f5318..1b76dec23c2a0 100644 --- a/paddle/phi/core/distributed/auto_parallel/process_mesh.h +++ b/paddle/phi/core/distributed/auto_parallel/process_mesh.h @@ -20,7 +20,6 @@ limitations under the License. */ #include #include -#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h" #include "paddle/phi/core/distributed/auto_parallel/device_mesh.h" #include "paddle/phi/core/distributed/auto_parallel/utils.h" #include "paddle/phi/core/enforce.h" diff --git a/test/auto_parallel/CMakeLists.txt b/test/auto_parallel/CMakeLists.txt index a735762cce658..ab2b09680c5ad 100644 --- a/test/auto_parallel/CMakeLists.txt +++ b/test/auto_parallel/CMakeLists.txt @@ -3,6 +3,7 @@ add_subdirectory(spmd_rules) add_subdirectory(hybrid_strategy) +add_subdirectory(custom_op) if(WITH_DISTRIBUTE AND WITH_GPU) diff --git a/test/auto_parallel/custom_op/CMakeLists.txt b/test/auto_parallel/custom_op/CMakeLists.txt new file mode 100644 index 0000000000000..b3537bc09c4e0 --- /dev/null +++ b/test/auto_parallel/custom_op/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LOCAL_ALL_ARCH ON) +set(LOCAL_ALL_PLAT ON) +if(WITH_DISTRIBUTE + AND WITH_GPU + AND (LINUX)) + py_test_modules( + test_semi_auto_parallel_custom_op + MODULES + test_semi_auto_parallel_custom_op + ENVS + "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python;PADDLE_SOURCE_DIR=${PROJECT_SOURCE_DIR};WITH_MKLDNN=${WITH_MKLDNN};MKLDNN_INSTALL_DIR=${MKLDNN_INSTALL_DIR};WITH_MKLDNN=${WITH_MKLDNN};WITH_GPU=${WITH_GPU};WITH_ROCM=${WITH_ROCM};externalError_INCLUDE_DIR=${externalError_INCLUDE_DIR};PYBIND_INCLUDE_DIR=${PYBIND_INCLUDE_DIR}" + ) + set_tests_properties(test_semi_auto_parallel_custom_op + PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 120) + +endif() diff --git a/test/auto_parallel/custom_op/custom_relu_op.cc b/test/auto_parallel/custom_op/custom_relu_op.cc new file mode 100644 index 0000000000000..7f76ab92cb2d1 --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_op.cc @@ -0,0 +1,134 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include + +#include "paddle/extension.h" +#include "paddle/phi/api/ext/spmd_infer.h" +#include "paddle/phi/infermeta/spmd_rules/rules.h" + +#define CHECK_CPU_INPUT(x) PD_CHECK(x.is_cpu(), #x " must be a CPU Tensor.") + +template +void relu_cpu_forward_kernel(const data_t* x_data, + data_t* out_data, + int64_t x_numel) { + PD_CHECK(x_data != nullptr, "x_data is nullptr."); + PD_CHECK(out_data != nullptr, "out_data is nullptr."); + for (int64_t i = 0; i < x_numel; ++i) { + out_data[i] = std::max(static_cast(0.), x_data[i]); + } +} + +template +void relu_cpu_backward_kernel(const data_t* grad_out_data, + const data_t* out_data, + data_t* grad_x_data, + int64_t out_numel) { + for (int64_t i = 0; i < out_numel; ++i) { + grad_x_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +std::vector relu_cpu_forward(const paddle::Tensor& x) { + auto out = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES( + x.type(), "relu_cpu_forward", ([&] { + relu_cpu_forward_kernel( + x.data(), out.data(), x.numel()); + })); + + return {out}; +} + +std::vector relu_cpu_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + auto grad_x = paddle::empty_like(x); + + PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] { + relu_cpu_backward_kernel( + grad_out.data(), + out.data(), + grad_x.data(), + out.size()); + })); + + return {grad_x}; +} + +std::vector relu_cuda_forward(const paddle::Tensor& x); +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out); + +std::vector ReluForward(const paddle::Tensor& x) { + if (x.is_cpu()) { + return relu_cpu_forward(x); + } else if (x.is_gpu()) { + return relu_cuda_forward(x); + } else { + PD_THROW("Not implemented."); + } +} + +std::vector ReluBackward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + if (x.is_cpu()) { + return relu_cpu_backward(x, out, grad_out); + } else if (x.is_gpu()) { + return relu_cuda_backward(x, out, grad_out); + } else { + PD_THROW("Not implemented."); + } +} + +phi::distributed::SpmdInfo ReluGradInferSpmd( + const phi::distributed::DistMetaTensor& x, + const phi::distributed::DistMetaTensor& out, + const phi::distributed::DistMetaTensor& out_grad) { + return phi::distributed::ElementwiseUnaryGradInferSpmd(x, out, out_grad); +} + +PD_BUILD_OP(custom_relu) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)) + .SetInferSpmdFn( + PD_INFER_SPMD_RULE(phi::distributed::ElementwiseUnaryInferSpmd)); + +PD_BUILD_GRAD_OP(custom_relu) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)) + .SetInferSpmdFn(PD_INFER_SPMD_RULE(ReluGradInferSpmd)); + +PD_BUILD_OP(custom_relu_no_spmd) + .Inputs({"X"}) + .Outputs({"Out"}) + .SetKernelFn(PD_KERNEL(ReluForward)); + +PD_BUILD_GRAD_OP(custom_relu_no_spmd) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .SetKernelFn(PD_KERNEL(ReluBackward)); + +PD_REGISTER_SPMD_RULE( + custom_relu, + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd), + PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmdReverse)); diff --git a/test/auto_parallel/custom_op/custom_relu_op.cu b/test/auto_parallel/custom_op/custom_relu_op.cu new file mode 100644 index 0000000000000..810ff75be5578 --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_op.cu @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. 
All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/extension.h" + +#define CHECK_GPU_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +template +__global__ void relu_cuda_forward_kernel(const data_t* x, + data_t* y, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + y[i] = x[i] > static_cast(0.) ? x[i] : static_cast(0.); + } +} + +template +__global__ void relu_cuda_backward_kernel(const data_t* dy, + const data_t* y, + data_t* dx, + int64_t num) { + int64_t gid = blockIdx.x * blockDim.x + threadIdx.x; + for (int64_t i = gid; i < num; i += blockDim.x * gridDim.x) { + dx[i] = dy[i] * (y[i] > static_cast(0.) ? static_cast(1.) + : static_cast(0.)); + } +} + +std::vector relu_cuda_forward(const paddle::Tensor& x) { + CHECK_GPU_INPUT(x); + auto out = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = x.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + x.type(), "relu_cuda_forward_kernel", ([&] { + relu_cuda_forward_kernel<<>>( + x.data(), out.data(), numel); + })); + + return {out}; +} + +std::vector relu_cuda_backward(const paddle::Tensor& x, + const paddle::Tensor& out, + const paddle::Tensor& grad_out) { + CHECK_GPU_INPUT(x); + CHECK_GPU_INPUT(out); + CHECK_GPU_INPUT(grad_out); + auto grad_x = paddle::empty_like(x); + + PD_CHECK(x.place() == paddle::DefaultGPUPlace()); + + int64_t numel = out.numel(); + int64_t block = 512; + int64_t grid = (numel + block - 1) / block; + PD_DISPATCH_FLOATING_AND_HALF_TYPES( + out.type(), "relu_cuda_backward_kernel", ([&] { + relu_cuda_backward_kernel<<>>( + grad_out.data(), + out.data(), + grad_x.mutable_data(x.place()), + numel); + })); + + return {grad_x}; +} diff --git a/test/auto_parallel/custom_op/custom_relu_setup.py b/test/auto_parallel/custom_op/custom_relu_setup.py new file mode 100644 index 0000000000000..567e7ac65d1e3 --- /dev/null +++ b/test/auto_parallel/custom_op/custom_relu_setup.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from utils import extra_compile_args, paddle_includes
+
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
+# Mac CI doesn't support GPU
+Extension = CUDAExtension
+sources = ['custom_relu_op.cc', 'custom_relu_op.cu']
+
+setup(
+    name='custom_relu',
+    ext_modules=Extension(
+        sources=sources,
+        include_dirs=paddle_includes,
+        extra_compile_args=extra_compile_args,
+        verbose=True,
+    ),
+)
diff --git a/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py b/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py
new file mode 100644
index 0000000000000..1de800abe1b99
--- /dev/null
+++ b/test/auto_parallel/custom_op/semi_auto_parallel_for_custom_op.py
@@ -0,0 +1,89 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+__dir__ = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(os.path.abspath(os.path.join(__dir__, '..')))
+
+from semi_auto_parallel_util import SemiAutoParallelTestBase
+
+import paddle
+import paddle.distributed as dist
+from paddle.framework import core
+
+import custom_relu  # noqa: F401 # pylint: disable=unused-import # isort:skip
+
+assert core.contains_spmd_rule("custom_relu")
+
+
+class TestCusomOpSemiAutoParallel(SemiAutoParallelTestBase):
+    def __init__(self):
+        super().__init__()
+        self._backend = os.getenv("backend")
+        self._seed = eval(os.getenv("seed"))
+
+    def check_placements(self, output, expected_placements):
+        assert (
+            output.placements == expected_placements
+        ), f"{output.placements} vs {expected_placements}"
+
+    def test_custom_relu(self):
+        shapes = [16, 4, 4]
+        specs = ['x', None, None]
+        inputs, outputs = self.runfunc_and_check(
+            inputs_shape=shapes,
+            inputs_specs=specs,
+            op_func=custom_relu.custom_relu,
+            with_backward=True,
+        )
+        self.check_placements(outputs, [dist.Shard(0)])
+
+    def test_custom_relu_no_spmd(self):
+        shapes = [16, 4, 4]
+        specs = ['x', None, None]
+        inputs, outputs = self.runfunc_and_check(
+            inputs_shape=shapes,
+            inputs_specs=specs,
+            op_func=custom_relu.custom_relu_no_spmd,
+            with_backward=True,
+        )
+        self.check_placements(outputs, [dist.Replicate()])
+
+    def test_custom_relu_no_shard(self):
+        shapes = [16, 4, 4]
+        specs = [None, None, None]
+        inputs, outputs = self.runfunc_and_check(
+            inputs_shape=shapes,
+            inputs_specs=specs,
+            op_func=custom_relu.custom_relu,
+            with_backward=True,
+        )
+        self.check_placements(outputs, [dist.Replicate()])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+        elif self._backend == "gpu":
+            paddle.set_device("gpu:" + str(dist.get_rank()))
+        else:
+            raise ValueError("Only support cpu or gpu backend.")
+        self.test_custom_relu_no_shard()
+        self.test_custom_relu()
+        self.test_custom_relu_no_spmd()
+
+
+if __name__ == '__main__':
+    TestCusomOpSemiAutoParallel().run_test_case()
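For orientation, the user-level flow that semi_auto_parallel_for_custom_op.py drives through SemiAutoParallelTestBase boils down to roughly the sketch below. It is not part of this patch; it assumes the extension built by custom_relu_setup.py is already importable as custom_relu and that the script runs on two ranks (e.g. via python -m paddle.distributed.launch).

# Hypothetical standalone sketch mirroring the test's placement checks; not part of this PR.
import paddle
import paddle.distributed as dist

import custom_relu  # extension assumed to be built and installed already

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Shard the input along dim 0; with the registered InferSpmd rule the
# elementwise custom op is expected to keep that sharding on its output.
x = dist.shard_tensor(paddle.rand([16, 4, 4]), mesh, [dist.Shard(0)])
out = custom_relu.custom_relu(x)
assert out.placements == [dist.Shard(0)]

# The variant registered without an InferSpmd rule falls back to the
# default replicated rule, so its output comes back replicated.
out_no_spmd = custom_relu.custom_relu_no_spmd(x)
assert out_no_spmd.placements == [dist.Replicate()]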
diff --git a/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py b/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py
new file mode 100644
index 0000000000000..a8014a81c2548
--- /dev/null
+++ b/test/auto_parallel/custom_op/test_semi_auto_parallel_custom_op.py
@@ -0,0 +1,52 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+from paddle.utils.cpp_extension.extension_utils import run_cmd
+
+
+class TestCusomOp(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=2, timeout=200, nnode=1)
+        self._default_envs = {"dtype": "float32", "seed": "2023"}
+        self._changeable_envs = {"backend": ["cpu", "gpu"]}
+        cur_dir = os.path.dirname(os.path.abspath(__file__))
+        # compile and install the custom op egg into site-packages in the background
+        if os.name == 'nt':
+            cmd = f'cd /d {cur_dir} && python custom_relu_setup.py install'
+        else:
+            cmd = (
+                f'cd {cur_dir} && {sys.executable} custom_relu_setup.py install'
+            )
+        run_cmd(cmd)
+
+    # test dynamic auto parallel run
+    def test_dynamic_auto_parallel(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            self.run_test_case(
+                "semi_auto_parallel_for_custom_op.py",
+                user_defined_envs=envs,
+            )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/auto_parallel/custom_op/utils.py b/test/auto_parallel/custom_op/utils.py
new file mode 100644
index 0000000000000..07f08648c0a62
--- /dev/null
+++ b/test/auto_parallel/custom_op/utils.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from site import getsitepackages
+
+# Test for extra compile args
+extra_cc_args = ['-w', '-g']
+extra_nvcc_args = ['-O3']
+extra_compile_args = {'cc': extra_cc_args, 'nvcc': extra_nvcc_args}
+
+
+def get_paddle_includes():
+    env_dict = os.environ
+    paddle_includes = []
+    paddle_includes.append(f"{env_dict.get('PADDLE_SOURCE_DIR')}")
+
+    # mkldnn
+    if env_dict.get("WITH_MKLDNN") == 'ON':
+        paddle_includes.append(f"{env_dict.get('MKLDNN_INSTALL_DIR')}/include")
+    if env_dict.get("WITH_GPU") == 'ON' or env_dict.get("WITH_ROCM") == 'ON':
+        paddle_includes.append(f"{env_dict.get('externalError_INCLUDE_DIR')}")
+        paddle_includes.append(f"{env_dict.get('PYBIND_INCLUDE_DIR')}")
+
+    for site_packages_path in getsitepackages():
+        paddle_includes.append(
+            os.path.join(site_packages_path, 'paddle', 'include')
+        )
+        paddle_includes.append(
+            os.path.join(site_packages_path, 'paddle', 'include', 'third_party')
+        )
+
+    return paddle_includes
+
+
+paddle_includes = get_paddle_includes()
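Taken together, the CMake ENVS above, utils.get_paddle_includes(), and custom_relu_setup.py are wired up by the launcher test; outside of CTest the same flow can be reproduced roughly as below. This is a sketch only: the environment values, paths, and launch flags are assumptions and are not part of this patch.

# Hypothetical local driver mirroring what test_semi_auto_parallel_custom_op.py does via CTest.
import os
import subprocess
import sys

cur_dir = os.path.dirname(os.path.abspath(__file__))
env = dict(os.environ, PADDLE_SOURCE_DIR="/path/to/Paddle", WITH_GPU="ON")

# Build and install the custom_relu extension (uses paddle_includes from utils.py).
subprocess.check_call(
    [sys.executable, "custom_relu_setup.py", "install"], cwd=cur_dir, env=env
)

# Run the two-rank test script with the envs the launcher test would set.
subprocess.check_call(
    [
        sys.executable, "-m", "paddle.distributed.launch", "--devices", "0,1",
        "semi_auto_parallel_for_custom_op.py",
    ],
    cwd=cur_dir,
    env=dict(env, backend="gpu", seed="2023", dtype="float32"),
)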