From 212ec21ce772ba4414d728f389f31ad6b2476044 Mon Sep 17 00:00:00 2001 From: liuzhenhai93 Date: Fri, 5 Jan 2024 09:53:51 +0000 Subject: [PATCH] suport reshard --- .../custom_operator/custom_operator_utils.cc | 197 +++++++++++------- .../distributed/auto_parallel/dist_tensor.cc | 15 ++ .../distributed/auto_parallel/dist_tensor.h | 7 + 3 files changed, 142 insertions(+), 77 deletions(-) diff --git a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc index 80c2c5b4c7be6f..abe90ea23c2264 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_utils.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_utils.cc @@ -543,12 +543,80 @@ phi::distributed::SpmdInfo RunInferSpmdFn( return spmd_info; } +std::vector> RunInferShapeFn( + const paddle::OpMetaInfo& op_info, + bool is_forward, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, + paddle::CustomOpKernelContext& ctx) { // NOLINT + auto& infer_shape_func = paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); + + std::vector> out_dims; + if (infer_shape_func) { + out_dims = + RunInferShapeFunc(ctx, infer_shape_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dims = + RunDefaultGradInferShapeFunc(ctx, inputs, outputs, is_double_grad); + } + } + + PADDLE_ENFORCE_EQ( + out_dims.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_shape return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dims.size())); + + return out_dims; +} + +std::vector> RunInferDtypeFn( + const paddle::OpMetaInfo& op_info, + bool is_forward, + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map, + paddle::CustomOpKernelContext& ctx) { // NOLINT + + auto& infer_dtype_func = paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); + std::vector> out_dtypes; + if (infer_dtype_func) { + out_dtypes = + RunInferDtypeFunc(ctx, infer_dtype_func, inputs, outputs, inplace_map); + } else { + if (is_forward) { + out_dtypes = RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); + } else { + out_dtypes = + RunDefaultGradInferDtypeFunc(ctx, inputs, outputs, is_double_grad); + } + } + PADDLE_ENFORCE_EQ( + out_dtypes.size(), + ctx.OutputRange().size(), + phi::errors::InvalidArgument( + "Custome op infer_dtype return size should be %d, but got %d.", + ctx.OutputRange().size(), + out_dtypes.size())); + return out_dtypes; +} + std:: tuple - PrepareCtxForAutoParallel(const paddle::OpMetaInfo& op_info, - bool is_forward, - bool is_double_grad, - paddle::CustomOpKernelContext& ctx) { // NOLINT + PrepareCtxForAutoParallel( + const paddle::OpMetaInfo& op_info, + bool is_forward, + bool is_double_grad, + paddle::CustomOpKernelContext& ctx, // NOLINT + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT bool run_auto_parallel = false; bool rank_is_in_current_mesh = true; phi::distributed::ProcessMesh current_process_mesh; @@ -590,83 +658,47 @@ std:: dev_ctx, x[i], spmd_info.first[i]); all_inputs->at(i).set_impl( std::make_shared(dist_input_i->value())); + dist_inputs.emplace_back(dist_input_i); } - } else { - auto& infer_shape_func = - paddle::OpMetaInfoHelper::GetInferShapeFn(op_info); - auto& infer_dtype_func = - paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info); - - std::vector> out_dims; - if (infer_shape_func) { - out_dims = RunInferShapeFunc( - ctx, infer_shape_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dims = - RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dims = RunDefaultGradInferShapeFunc( - ctx, inputs, outputs, is_double_grad); - } - } + } - std::vector> out_dtypes; - if (infer_dtype_func) { - out_dtypes = RunInferDtypeFunc( - ctx, infer_dtype_func, inputs, outputs, inplace_map); - } else { - if (is_forward) { - out_dtypes = - RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map); - } else { - out_dtypes = RunDefaultGradInferDtypeFunc( - ctx, inputs, outputs, is_double_grad); - } - } + std::vector> out_dims = + RunInferShapeFn(op_info, is_forward, inputs, outputs, inplace_map, ctx); - PADDLE_ENFORCE_EQ( - out_dims.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_shape return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dims.size())); + std::vector> out_dtypes = + RunInferDtypeFn(op_info, is_forward, inputs, outputs, inplace_map, ctx); + for (size_t i = 0; i < out_dims.size(); ++i) { + const auto& out_dim = out_dims.at(i); + const auto& out_dtype = out_dtypes.at(i); + const auto& pair = ctx.OutputRangeAt(i); PADDLE_ENFORCE_EQ( - out_dtypes.size(), - ctx.OutputRange().size(), - phi::errors::InvalidArgument( - "Custome op infer_dtype return size should be %d, but got %d.", - ctx.OutputRange().size(), - out_dtypes.size())); - - for (size_t i = 0; i < out_dims.size(); ++i) { - const auto& out_dim = out_dims.at(i); - const auto& out_dtype = out_dtypes.at(i); - const auto& pair = ctx.OutputRangeAt(i); - PADDLE_ENFORCE_EQ( - out_dim.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dim.size())); - PADDLE_ENFORCE_EQ( - out_dtype.size(), - pair.second - pair.first, - phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " - "size should be %d, but got %d.", - i, - pair.second - pair.first, - out_dtype.size())); - - if (out_dim.size() == 1) { + out_dim.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dim.size())); + PADDLE_ENFORCE_EQ( + out_dtype.size(), + pair.second - pair.first, + phi::errors::InvalidArgument("custome op infer_shape result[%d]'s " + "size should be %d, but got %d.", + i, + pair.second - pair.first, + out_dtype.size())); + + if (out_dim.size() == 1) { + output_dims.emplace_back(out_dim[0]); + if (!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[0], out_dtype[0]); - } else { - for (size_t j = pair.first; j < pair.second; j++) { + } + } else { + for (size_t j = pair.first; j < pair.second; j++) { + output_dims.emplace_back(out_dim[j]); + if (!rank_is_in_current_mesh) { *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor( current_process_mesh, out_dim[j], out_dtype[j]); } @@ -686,7 +718,10 @@ void TransCtxTensorsToDistTensors( paddle::CustomOpKernelContext& ctx, // NOLINT bool run_auto_parallel, const phi::distributed::ProcessMesh& current_process_mesh, - const phi::distributed::SpmdInfo& spmd_info) { + const phi::distributed::SpmdInfo& spmd_info, + std::vector>& + dist_inputs, // NOLINT + std::vector& output_dims) { // NOLINT if (run_auto_parallel) { std::vector* output_all = ctx.AllMutableOutput(); for (size_t i = 0; i < output_all->size(); ++i) { @@ -701,6 +736,7 @@ void TransCtxTensorsToDistTensors( } auto dist_t = std::make_shared( std::dynamic_pointer_cast(tensor.impl()), + output_dims[i], dist_attr); tensor.set_impl(dist_t); } @@ -735,6 +771,9 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.ConstructInplaceIndex(inputs, outputs, inplace_map); #ifdef PADDLE_WITH_DISTRIBUTE + // for output + std::vector> dist_inputs; + std::vector output_dims; auto result = PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx); bool run_auto_parallel = std::get<0>(result); @@ -766,8 +805,12 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info, ctx.AssignInplaceOutputs(); #ifdef PADDLE_WITH_DISTRIBUTE - TransCtxTensorsToDistTensors( - ctx, run_auto_parallel, current_process_mesh, spmd_info); + TransCtxTensorsToDistTensors(ctx, + run_auto_parallel, + current_process_mesh, + spmd_info, + dist_inputs, + output_dims); #endif } diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc index fff9af10339a60..2070c6227438db 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.cc @@ -162,6 +162,21 @@ DistTensor::DistTensor(const std::shared_ptr& local_value, } } +DistTensor::DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr) + : global_dims_(global_dims), dist_attr_(dist_attr) { + process_mesh_ = dist_attr_.process_mesh(); + placements_ = ToPlacements(dist_attr); + if (IsCurRankInMesh(process_mesh_)) { + value_ = local_value; + } else { + value_ = std::make_shared( + std::make_shared(nullptr, 0, local_value->place()), + phi::DenseTensorMeta(local_value->dtype(), global_dims_)); + } +} + DistTensor::DistTensor(const std::shared_ptr& global_value, const ProcessMesh& process_mesh, const Placements& placements) diff --git a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h index 55b35ffe5c25a5..5ad10c76b25087 100644 --- a/paddle/phi/core/distributed/auto_parallel/dist_tensor.h +++ b/paddle/phi/core/distributed/auto_parallel/dist_tensor.h @@ -58,6 +58,13 @@ class DistTensor final const ProcessMesh& process_mesh, const Placements& placements); + /// \brief Construct a dist tensor based local dense tensor. + /// \param global_dims The global dim of the dist tensor. + /// \param dist_attr The distributed attributes of the current tensor. + DistTensor(const std::shared_ptr& local_value, + const DDim& global_dims, + const TensorDistAttr& dist_attr); + /// \brief Construct a dist tensor based local dense tensor. /// \param global_dims The global dim of the dist tensor. /// \param process_mesh The process mesh of the current tensor.