support reshard
liuzhenhai93 committed Jan 5, 2024
1 parent a8df9cb commit 212ec21
Showing 3 changed files with 142 additions and 77 deletions.
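
In short: the commit factors the inline infer_shape/infer_dtype dispatch into standalone RunInferShapeFn and RunInferDtypeFn helpers, threads two new out-parameters (dist_inputs, which collects the resharded inputs, and output_dims, which records the inferred global output dims) from PrepareCtxForAutoParallel through to TransCtxTensorsToDistTensors, and adds a DistTensor constructor that builds a distributed tensor from a rank-local DenseTensor plus explicit global dims, so outputs keep their global shape when converted back to DistTensors.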
197 changes: 120 additions & 77 deletions paddle/fluid/eager/custom_operator/custom_operator_utils.cc
@@ -543,12 +543,80 @@ phi::distributed::SpmdInfo RunInferSpmdFn(
return spmd_info;
}

std::vector<std::vector<phi::DDim>> RunInferShapeFn(
const paddle::OpMetaInfo& op_info,
bool is_forward,
bool is_double_grad,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map,
paddle::CustomOpKernelContext& ctx) { // NOLINT
auto& infer_shape_func = paddle::OpMetaInfoHelper::GetInferShapeFn(op_info);

std::vector<std::vector<phi::DDim>> out_dims;
if (infer_shape_func) {
out_dims =
RunInferShapeFunc(ctx, infer_shape_func, inputs, outputs, inplace_map);
} else {
if (is_forward) {
out_dims = RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map);
} else {
out_dims =
RunDefaultGradInferShapeFunc(ctx, inputs, outputs, is_double_grad);
}
}

PADDLE_ENFORCE_EQ(
out_dims.size(),
ctx.OutputRange().size(),
phi::errors::InvalidArgument(
"Custome op infer_shape return size should be %d, but got %d.",
ctx.OutputRange().size(),
out_dims.size()));

return out_dims;
}

std::vector<std::vector<phi::DataType>> RunInferDtypeFn(
const paddle::OpMetaInfo& op_info,
bool is_forward,
bool is_double_grad,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map,
paddle::CustomOpKernelContext& ctx) { // NOLINT

auto& infer_dtype_func = paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info);
std::vector<std::vector<phi::DataType>> out_dtypes;
if (infer_dtype_func) {
out_dtypes =
RunInferDtypeFunc(ctx, infer_dtype_func, inputs, outputs, inplace_map);
} else {
if (is_forward) {
out_dtypes = RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map);
} else {
out_dtypes =
RunDefaultGradInferDtypeFunc(ctx, inputs, outputs, is_double_grad);
}
}
PADDLE_ENFORCE_EQ(
out_dtypes.size(),
ctx.OutputRange().size(),
phi::errors::InvalidArgument(
"Custome op infer_dtype return size should be %d, but got %d.",
ctx.OutputRange().size(),
out_dtypes.size()));
return out_dtypes;
}
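
For context: RunInferShapeFn and RunInferDtypeFn take the user-defined branch only when the custom op has registered infer_shape/infer_dtype hooks, and otherwise fall back to the default (grad) inference functions. A minimal sketch of a registration that exercises the user-defined branch, written against Paddle's public custom-op macros (the op name and function names are illustrative, not part of this commit):

#include "paddle/extension.h"

// Trivial identity kernel; a real op would do actual compute here.
std::vector<paddle::Tensor> IdentityForward(const paddle::Tensor& x) {
  return {x};
}

// With these hooks registered, RunInferShapeFn/RunInferDtypeFn dispatch to
// them; without them, the Default(Grad)Infer*Func fallbacks above are used.
std::vector<std::vector<int64_t>> IdentityInferShape(
    const std::vector<int64_t>& x_shape) {
  return {x_shape};  // output shape equals input shape
}

std::vector<paddle::DataType> IdentityInferDtype(paddle::DataType x_dtype) {
  return {x_dtype};  // output dtype equals input dtype
}

PD_BUILD_OP(custom_identity)
    .Inputs({"X"})
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(IdentityForward))
    .SetInferShapeFn(PD_INFER_SHAPE(IdentityInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(IdentityInferDtype));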

std::
    tuple<bool, bool, phi::distributed::ProcessMesh, phi::distributed::SpmdInfo>
-    PrepareCtxForAutoParallel(const paddle::OpMetaInfo& op_info,
-                              bool is_forward,
-                              bool is_double_grad,
-                              paddle::CustomOpKernelContext& ctx) {  // NOLINT
+    PrepareCtxForAutoParallel(
+        const paddle::OpMetaInfo& op_info,
+        bool is_forward,
+        bool is_double_grad,
+        paddle::CustomOpKernelContext& ctx,  // NOLINT
+        std::vector<std::shared_ptr<phi::distributed::DistTensor>>&
+            dist_inputs,  // NOLINT
+        std::vector<phi::DDim>& output_dims) {  // NOLINT
bool run_auto_parallel = false;
bool rank_is_in_current_mesh = true;
phi::distributed::ProcessMesh current_process_mesh;
@@ -590,83 +658,47 @@ std::
dev_ctx, x[i], spmd_info.first[i]);
all_inputs->at(i).set_impl(
std::make_shared<phi::DenseTensor>(dist_input_i->value()));
+        dist_inputs.emplace_back(dist_input_i);
      }
-    } else {
-      auto& infer_shape_func =
-          paddle::OpMetaInfoHelper::GetInferShapeFn(op_info);
-      auto& infer_dtype_func =
-          paddle::OpMetaInfoHelper::GetInferDtypeFn(op_info);
-
-      std::vector<std::vector<phi::DDim>> out_dims;
-      if (infer_shape_func) {
-        out_dims = RunInferShapeFunc(
-            ctx, infer_shape_func, inputs, outputs, inplace_map);
-      } else {
-        if (is_forward) {
-          out_dims =
-              RunDefaultInferShapeFunc(ctx, inputs, outputs, inplace_map);
-        } else {
-          out_dims = RunDefaultGradInferShapeFunc(
-              ctx, inputs, outputs, is_double_grad);
-        }
-      }
-
-      std::vector<std::vector<phi::DataType>> out_dtypes;
-      if (infer_dtype_func) {
-        out_dtypes = RunInferDtypeFunc(
-            ctx, infer_dtype_func, inputs, outputs, inplace_map);
-      } else {
-        if (is_forward) {
-          out_dtypes =
-              RunDefaultInferDtypeFunc(ctx, inputs, outputs, inplace_map);
-        } else {
-          out_dtypes = RunDefaultGradInferDtypeFunc(
-              ctx, inputs, outputs, is_double_grad);
-        }
-      }
-
-      PADDLE_ENFORCE_EQ(
-          out_dims.size(),
-          ctx.OutputRange().size(),
-          phi::errors::InvalidArgument(
-              "Custom op infer_shape return size should be %d, but got %d.",
-              ctx.OutputRange().size(),
-              out_dims.size()));
-
-      for (size_t i = 0; i < out_dims.size(); ++i) {
-        const auto& out_dim = out_dims.at(i);
-        const auto& out_dtype = out_dtypes.at(i);
-        const auto& pair = ctx.OutputRangeAt(i);
-        PADDLE_ENFORCE_EQ(
-            out_dtypes.size(),
-            ctx.OutputRange().size(),
-            phi::errors::InvalidArgument(
-                "Custom op infer_dtype return size should be %d, but got %d.",
-                ctx.OutputRange().size(),
-                out_dtypes.size()));
-        PADDLE_ENFORCE_EQ(
-            out_dim.size(),
-            pair.second - pair.first,
-            phi::errors::InvalidArgument("custom op infer_shape result[%d]'s "
-                                         "size should be %d, but got %d.",
-                                         i,
-                                         pair.second - pair.first,
-                                         out_dim.size()));
-        PADDLE_ENFORCE_EQ(
-            out_dtype.size(),
-            pair.second - pair.first,
-            phi::errors::InvalidArgument("custom op infer_shape result[%d]'s "
-                                         "size should be %d, but got %d.",
-                                         i,
-                                         pair.second - pair.first,
-                                         out_dtype.size()));
-
-        if (out_dim.size() == 1) {
+    }
+
+    std::vector<std::vector<phi::DDim>> out_dims = RunInferShapeFn(
+        op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
+
+    std::vector<std::vector<phi::DataType>> out_dtypes = RunInferDtypeFn(
+        op_info, is_forward, is_double_grad, inputs, outputs, inplace_map, ctx);
+
+    for (size_t i = 0; i < out_dims.size(); ++i) {
+      const auto& out_dim = out_dims.at(i);
+      const auto& out_dtype = out_dtypes.at(i);
+      const auto& pair = ctx.OutputRangeAt(i);
+      PADDLE_ENFORCE_EQ(
+          out_dim.size(),
+          pair.second - pair.first,
+          phi::errors::InvalidArgument("custom op infer_shape result[%d]'s "
+                                       "size should be %d, but got %d.",
+                                       i,
+                                       pair.second - pair.first,
+                                       out_dim.size()));
+      PADDLE_ENFORCE_EQ(
+          out_dtype.size(),
+          pair.second - pair.first,
+          phi::errors::InvalidArgument("custom op infer_shape result[%d]'s "
+                                       "size should be %d, but got %d.",
+                                       i,
+                                       pair.second - pair.first,
+                                       out_dtype.size()));
+
+      if (out_dim.size() == 1) {
+        output_dims.emplace_back(out_dim[0]);
+        if (!rank_is_in_current_mesh) {
          *(ctx.MutableOutputAt(pair.first)) = BuildEmptyDistPaddleTensor(
              current_process_mesh, out_dim[0], out_dtype[0]);
-        } else {
-          for (size_t j = pair.first; j < pair.second; j++) {
+        }
+      } else {
+        for (size_t j = pair.first; j < pair.second; j++) {
+          output_dims.emplace_back(out_dim[j]);
+          if (!rank_is_in_current_mesh) {
            *(ctx.MutableOutputAt(j)) = BuildEmptyDistPaddleTensor(
                current_process_mesh, out_dim[j], out_dtype[j]);
+          }
@@ -686,7 +718,10 @@ void TransCtxTensorsToDistTensors(
paddle::CustomOpKernelContext& ctx, // NOLINT
bool run_auto_parallel,
const phi::distributed::ProcessMesh& current_process_mesh,
-    const phi::distributed::SpmdInfo& spmd_info) {
+    const phi::distributed::SpmdInfo& spmd_info,
+    std::vector<std::shared_ptr<phi::distributed::DistTensor>>&
+        dist_inputs,  // NOLINT
+    std::vector<phi::DDim>& output_dims) {  // NOLINT
if (run_auto_parallel) {
std::vector<Tensor>* output_all = ctx.AllMutableOutput();
for (size_t i = 0; i < output_all->size(); ++i) {
@@ -701,6 +736,7 @@
}
auto dist_t = std::make_shared<phi::distributed::DistTensor>(
std::dynamic_pointer_cast<phi::DenseTensor>(tensor.impl()),
output_dims[i],
dist_attr);
tensor.set_impl(dist_t);
}
@@ -735,6 +771,9 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info,
ctx.ConstructInplaceIndex(inputs, outputs, inplace_map);

#ifdef PADDLE_WITH_DISTRIBUTE
+  // for output
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> dist_inputs;
+  std::vector<phi::DDim> output_dims;
-  auto result =
-      PrepareCtxForAutoParallel(op_info, is_forward, is_double_grad, ctx);
+  auto result = PrepareCtxForAutoParallel(
+      op_info, is_forward, is_double_grad, ctx, dist_inputs, output_dims);
bool run_auto_parallel = std::get<0>(result);
@@ -766,8 +805,12 @@ void run_custom_op_impl(const paddle::OpMetaInfo& op_info,
ctx.AssignInplaceOutputs();

#ifdef PADDLE_WITH_DISTRIBUTE
-  TransCtxTensorsToDistTensors(
-      ctx, run_auto_parallel, current_process_mesh, spmd_info);
+  TransCtxTensorsToDistTensors(ctx,
+                               run_auto_parallel,
+                               current_process_mesh,
+                               spmd_info,
+                               dist_inputs,
+                               output_dims);
#endif
}

15 changes: 15 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -162,6 +162,21 @@ DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
}
}

DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
const DDim& global_dims,
const TensorDistAttr& dist_attr)
: global_dims_(global_dims), dist_attr_(dist_attr) {
process_mesh_ = dist_attr_.process_mesh();
placements_ = ToPlacements(dist_attr);
if (IsCurRankInMesh(process_mesh_)) {
value_ = local_value;
} else {
value_ = std::make_shared<DenseTensor>(
std::make_shared<phi::Allocation>(nullptr, 0, local_value->place()),
phi::DenseTensorMeta(local_value->dtype(), global_dims_));
}
}
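
A usage sketch of the new constructor (hypothetical, not part of the commit): each rank wraps the DenseTensor it already holds, and the explicit global_dims preserve the logical shape, which cannot be inferred from a sharded local value. WrapLocalShard and the concrete mesh and shapes below are assumptions for illustration:

// Wrap a rank-local [32, 128] shard into a DistTensor with global shape
// [64, 128], sharded along dim 0 over a 2-rank, 1-D mesh named "x".
std::shared_ptr<phi::distributed::DistTensor> WrapLocalShard(
    const std::shared_ptr<phi::DenseTensor>& local_value) {
  phi::distributed::ProcessMesh mesh({2}, {0, 1}, {"x"});
  phi::distributed::TensorDistAttr dist_attr;
  dist_attr.set_process_mesh(mesh);
  dist_attr.set_dims_mapping({0, -1});  // dim 0 sharded, dim 1 replicated
  return std::make_shared<phi::distributed::DistTensor>(
      local_value, phi::make_ddim({64, 128}), dist_attr);
}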

DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
const ProcessMesh& process_mesh,
const Placements& placements)
7 changes: 7 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -58,6 +58,13 @@ class DistTensor final
const ProcessMesh& process_mesh,
const Placements& placements);

/// \brief Construct a dist tensor based on a local dense tensor.
/// \param local_value The local dense tensor of the current rank.
/// \param global_dims The global dims of the dist tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const std::shared_ptr<phi::DenseTensor>& local_value,
const DDim& global_dims,
const TensorDistAttr& dist_attr);

/// \brief Construct a dist tensor based on a local dense tensor.
/// \param global_dims The global dim of the dist tensor.
/// \param process_mesh The process mesh of the current tensor.
