[Semi Auto] Refactor Replicated Rule (PaddlePaddle#56839)
* adapt general spmd rule

* polish details

* add new rules

* bugfix for set_partial

* bugfix

* unit tests

* adapt arguments for both single tensors and vectors of tensors

---------

Co-authored-by: Chen Weihang <chenweihang@baidu.com>
JZ-LIANG and chenwhql authored Sep 12, 2023
1 parent ca3fa62 commit e8bdafa
Showing 15 changed files with 1,207 additions and 25 deletions.
77 changes: 77 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -14,6 +14,7 @@

#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <utility>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
@@ -381,6 +382,44 @@ void BindAutoParallel(py::module *m) {
}
return self.InferForward(ctx);
})
.def("infer_forward", // for op that have vector argument
[](const phi::distributed::SpmdRule &self,
const std::vector<std::pair<int, int>> &input_ranges,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<phi::Attribute> &attrs) {
/*
To distinguish a single tensor argument from a vector argument
containing one tensor:
  end - start == 0: single tensor
  end - start == 1: vector containing one tensor
e.g. input_ranges: [(0, 0), (1, 3), (3, 4)]
     + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
*/
phi::distributed::InferSpmdContext ctx;
paddle::small_vector<phi::distributed::DistMetaTensor,
phi::kInputSmallVectorSize>
ins;
for (auto &range : input_ranges) {
if (range.second - range.first == 0) {
auto &in = input_specs.at(range.first);
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
} else {
int start = range.first;
int end = range.second;
ins.reserve(end - start);
for (int i = start; i < end; ++i) {
auto &in = input_specs.at(i);
ins.emplace_back(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}
ctx.EmplaceBackInputs(ins);
ins.clear();
}
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferForward(ctx);
})
.def("infer_backward",
[](const phi::distributed::SpmdRule &self,
const std::vector<DistTensorSpec> &input_specs,
@@ -399,6 +438,44 @@ void BindAutoParallel(py::module *m) {
ctx.EmplaceBackAttr(attr);
}
return self.InferBackward(ctx);
})
.def("infer_backward", // for op that have vector argument
[](const phi::distributed::SpmdRule &self,
const std::vector<std::pair<int, int>> &input_ranges,
const std::vector<DistTensorSpec> &input_specs,
const std::vector<phi::Attribute> &attrs) {
/*
To distinguish a single tensor argument from a vector argument
containing one tensor:
  end - start == 0: single tensor
  end - start == 1: vector containing one tensor
e.g. input_ranges: [(0, 0), (1, 3), (3, 4)]
     + input_specs: [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
*/
phi::distributed::InferSpmdContext ctx;
paddle::small_vector<phi::distributed::DistMetaTensor,
phi::kInputSmallVectorSize>
ins;
for (auto &range : input_ranges) {
if (range.second - range.first == 0) {
auto &in = input_specs.at(range.first);
ctx.EmplaceBackInput(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
} else {
int start = range.first;
int end = range.second;
ins.reserve(end - start);
for (int i = start; i < end; ++i) {
auto &in = input_specs.at(i);
ins.emplace_back(phi::distributed::DistMetaTensor(
phi::make_ddim(in.shape()), in.dist_attr()));
}
ctx.EmplaceBackInputs(ins);
ins.clear();
}
}
for (auto &attr : attrs) {
ctx.EmplaceBackAttr(attr);
}
return self.InferBackward(ctx);
});

py::class_<DistTensorSpec>(*m, "DistTensorSpec")
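To make the range convention in the bindings above concrete, here is a minimal standalone sketch (plain std types and illustrative names only, not Paddle APIs) of how the (start, end) pairs carve the flat spec list into single-tensor and vector arguments:

#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  // Flat spec list and ranges, mirroring the example in the binding's comment:
  // [(0, 0), (1, 3), (3, 4)] + [t0, t1, t2, t3] --> t0, [t1, t2], [t3]
  std::vector<std::string> specs = {"t0", "t1", "t2", "t3"};
  std::vector<std::pair<int, int>> ranges = {{0, 0}, {1, 3}, {3, 4}};

  for (const auto& range : ranges) {
    if (range.second - range.first == 0) {
      // start == end: a single tensor argument at index `start`.
      std::cout << specs.at(range.first) << "\n";
    } else {
      // Otherwise: a vector argument covering [start, end).
      std::cout << "[";
      for (int i = range.first; i < range.second; ++i) {
        std::cout << specs.at(i) << (i + 1 < range.second ? ", " : "");
      }
      std::cout << "]\n";
    }
  }
  return 0;  // prints: t0, then [t1, t2], then [t3]
}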
4 changes: 4 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -102,6 +102,10 @@ void TensorDistAttr::set_partial_status(const std::vector<int64_t>& dims,
"Trying to Set dim %d as Partial which is already a Partial dim.",
dim));
}
if (std::count(dims_mapping_.begin(), dims_mapping_.end(), dim)) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Trying to Set dim %d as Partial which is a Sharding dim.", dim));
}
partial_status_.emplace(dim, type);
}
}
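The check added above enforces an invariant worth spelling out: a mesh dimension that already shards some tensor axis (i.e., appears in dims_mapping_) cannot also carry partial status. A simplified sketch of that rule, with assumed plain types in place of TensorDistAttr:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// True if mesh_dim may be marked partial: it must not already appear in
// dims_mapping, where it would mean some tensor axis is sharded on it.
bool CanSetPartial(const std::vector<int64_t>& dims_mapping, int64_t mesh_dim) {
  return std::count(dims_mapping.begin(), dims_mapping.end(), mesh_dim) == 0;
}

int main() {
  std::vector<int64_t> dims_mapping = {0, -1};  // tensor axis 0 sharded on mesh dim 0
  assert(!CanSetPartial(dims_mapping, 0));      // conflict: already a sharding dim
  assert(CanSetPartial(dims_mapping, 1));       // mesh dim 1 is free
  return 0;
}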
28 changes: 28 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -18,7 +18,18 @@ namespace phi {
namespace distributed {

void InferSpmdContext::EmplaceBackInput(DistMetaTensor input) {
int index = static_cast<int>(inputs_.size());
inputs_.emplace_back(std::move(input));
input_range_.emplace_back(std::pair<int, int>(index, index + 1));
}

void InferSpmdContext::EmplaceBackInputs(
paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs) {
int index = static_cast<int>(inputs_.size());
input_range_.emplace_back(std::pair<int, int>(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}

void InferSpmdContext::EmplaceBackAttr(Attribute attr) {
@@ -63,6 +74,23 @@ const Attribute& InferSpmdContext::AttrAt(size_t idx) const {
return attrs_.at(idx);
}

const std::pair<int, int>& InferSpmdContext::InputRangeAt(size_t idx) const {
return input_range_.at(idx);
}

const std::vector<const DistMetaTensor*> InferSpmdContext::InputsBetween(
size_t start, size_t end) const {
std::vector<const DistMetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
auto& in = inputs_.at(i);
result.emplace_back(&in);
// result.emplace_back(in.initialized() ? &in : nullptr);
}

return result;
}

SpmdRuleFactory& SpmdRuleFactory::Instance() {
static SpmdRuleFactory g_spmd_rule_map;
return g_spmd_rule_map;
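The EmplaceBackInput/EmplaceBackInputs methods above keep inputs_ flat while input_range_ records one [start, end) pair per logical argument, so InputRangeAt and InputsBetween can recover each argument's slice later. A reduced sketch of that bookkeeping (illustrative names, plain std types, not Paddle code):

#include <cassert>
#include <utility>
#include <vector>

struct MiniContext {
  std::vector<int> inputs_;                       // stand-in for DistMetaTensor
  std::vector<std::pair<int, int>> input_range_;  // [start, end) per argument

  void EmplaceBackInput(int input) {
    int index = static_cast<int>(inputs_.size());
    inputs_.push_back(input);
    input_range_.emplace_back(index, index + 1);  // single tensor: width-1 range
  }

  void EmplaceBackInputs(const std::vector<int>& inputs) {
    int index = static_cast<int>(inputs_.size());
    input_range_.emplace_back(index, index + static_cast<int>(inputs.size()));
    inputs_.insert(inputs_.end(), inputs.begin(), inputs.end());
  }
};

int main() {
  MiniContext ctx;
  ctx.EmplaceBackInput(10);         // argument 0: a single tensor
  ctx.EmplaceBackInputs({20, 21});  // argument 1: a vector of two tensors
  assert(ctx.input_range_.at(0) == std::make_pair(0, 1));
  assert(ctx.input_range_.at(1) == std::make_pair(1, 3));
  assert(ctx.inputs_.size() == 3);  // flat storage: {10, 20, 21}
  return 0;
}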
27 changes: 27 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h
@@ -45,9 +45,15 @@ class InferSpmdContext {

void EmplaceBackInput(DistMetaTensor input);
void EmplaceBackAttr(Attribute attr);
void EmplaceBackInputs(
paddle::small_vector<DistMetaTensor, phi::kInputSmallVectorSize> inputs);

const DistMetaTensor& InputAt(size_t idx) const;

const std::pair<int, int>& InputRangeAt(size_t idx) const;
const std::vector<const DistMetaTensor*> InputsBetween(size_t start,
size_t end) const;

template <typename AttrType>
AttrType AttrAt(size_t idx) const;

@@ -59,6 +65,9 @@
// Because the attribute arguments of dygraph do not have `attr name`,
// so we use vector instead of map
paddle::small_vector<Attribute, phi::kAttrSmallVectorSize> attrs_;
// for vector arguments
paddle::small_vector<std::pair<int, int>, phi::kInputSmallVectorSize>
input_range_;
};

using InferSpmdFn = SpmdInfo (*)(const InferSpmdContext&);
@@ -98,6 +107,24 @@ struct InferSpmdFnImpl<Return (*)(Args...), infer_spmd_fn> {
}
};

// for vector slots
template <typename... Tail>
struct InferSpmdFnCallHelper<const std::vector<const DistMetaTensor*>&,
Tail...> {
template <int in_idx, int attr_idx, typename... PreviousArgs>
static SpmdInfo Call(const InferSpmdContext& ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"InferSpmd's Input should appear before Attributes.");

const std::pair<int, int> range = ctx.InputRangeAt(in_idx);
std::vector<const DistMetaTensor*> arg =
ctx.InputsBetween(range.first, range.second);
return InferSpmdFnCallHelper<Tail...>::template Call<in_idx + 1,
attr_idx>(
ctx, pargs..., arg);
}
};

#define PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(attr_type) \
template <typename... Tail> \
struct InferSpmdFnCallHelper<attr_type, Tail...> { \
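The new specialization above slots vector arguments into the existing compile-time dispatch: each InferSpmdFnCallHelper specialization peels one parameter type off the rule's signature, materializes it from the context, and recurses until the rule can be invoked. A stripped-down sketch of that pattern with stand-in types (illustrative only; the real helper also threads attribute indices and error checks):

#include <cassert>
#include <utility>
#include <vector>

using Tensor = int;  // stand-in for DistMetaTensor

struct MiniCtx {
  std::vector<Tensor> inputs;
  std::vector<std::pair<int, int>> ranges;  // [start, end) per argument
};

template <typename T>
struct TypeTag {};

template <typename Fn, Fn fn>
struct FnImpl;

template <typename Return, typename... Args, Return (*fn)(Args...)>
struct FnImpl<Return (*)(Args...), fn> {
  static Return Call(const MiniCtx& ctx) {
    return CallHelper<Args..., TypeTag<int>>::template Call<0>(ctx);
  }

 private:
  template <typename... Rest>
  struct CallHelper;

  // Single-tensor slot: its range has width 1.
  template <typename... Tail>
  struct CallHelper<const Tensor&, Tail...> {
    template <int in_idx, typename... Prev>
    static Return Call(const MiniCtx& ctx, const Prev&... prev) {
      const Tensor& arg = ctx.inputs.at(ctx.ranges.at(in_idx).first);
      return CallHelper<Tail...>::template Call<in_idx + 1>(ctx, prev..., arg);
    }
  };

  // Vector slot (the case the new specialization handles): gather [start, end).
  template <typename... Tail>
  struct CallHelper<const std::vector<const Tensor*>&, Tail...> {
    template <int in_idx, typename... Prev>
    static Return Call(const MiniCtx& ctx, const Prev&... prev) {
      const std::pair<int, int>& range = ctx.ranges.at(in_idx);
      std::vector<const Tensor*> arg;
      for (int i = range.first; i < range.second; ++i) {
        arg.push_back(&ctx.inputs.at(i));
      }
      return CallHelper<Tail...>::template Call<in_idx + 1>(ctx, prev..., arg);
    }
  };

  // End marker: all parameters materialized, invoke the rule itself.
  template <typename T>
  struct CallHelper<TypeTag<T>> {
    template <int in_idx, typename... Prev>
    static Return Call(const MiniCtx& ctx, const Prev&... prev) {
      (void)ctx;
      return fn(prev...);
    }
  };
};

// An example rule taking one tensor and one vector of tensors.
int SumRule(const Tensor& x, const std::vector<const Tensor*>& ys) {
  int sum = x;
  for (const Tensor* y : ys) sum += *y;
  return sum;
}

int main() {
  // inputs = [1, 2, 3]; argument 0 is inputs[0], argument 1 is inputs[1..3).
  MiniCtx ctx{{1, 2, 3}, {{0, 1}, {1, 3}}};
  int result = FnImpl<decltype(&SumRule), &SumRule>::Call(ctx);
  assert(result == 6);  // 1 + 2 + 3
  return 0;
}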
164 changes: 164 additions & 0 deletions paddle/phi/infermeta/spmd_rules/default_data_parallel.cc
@@ -0,0 +1,164 @@
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
namespace distributed {

using phi::distributed::auto_parallel::str_join;

////////////////// Utils Functions //////////////////
std::vector<int64_t> GetDefaultDataParallelDimsmapping(
const int64_t batch_axis_dim, const int ndim) {
std::vector<int64_t> dims_mapping(ndim, -1);
dims_mapping[0] = batch_axis_dim;
return dims_mapping;
}

////////////////// InferMeta(Contains SPMD) Functions //////////////////

SpmdInfo DefaultDataParallelSpmdInferForward(
const std::vector<const DistMetaTensor*>& ins,
const std::vector<const DistMetaTensor*>& outs) {
// Step1: Build Einsum Notation for input tensors' batch axes
int64_t ninputs = ins.size();
int64_t noutputs = outs.size();
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
std::string batch_axis = "b";

for (int64_t i = 0; i < ninputs; ++i) {
axes_sharding_info.push_back(
{batch_axis, {ins[i]->dist_attr().dims_mapping()[0]}});
}

// Step2: Sharding Merge
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
int64_t batch_axis_dim = axis_to_dim_map[batch_axis];

// Step3: Infer Output's Batch Axis Dims Mapping.
std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < noutputs; i++) {
int ndim = outs[i]->dims().size();
TensorDistAttr dist_attr_dst =
CopyTensorDistAttrForOutput(ins[0]->dist_attr());
std::vector<int64_t> dst_dims_mapping =
GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim);
dist_attr_dst.set_dims_mapping(dst_dims_mapping);
output_dist_attrs.emplace_back(dist_attr_dst);
}

// Step4: Update Inputs' Batch Axis Dims Mapping with the merged batch axis.
std::vector<TensorDistAttr> dst_input_dist_attrs;
for (int64_t i = 0; i < ninputs; i++) {
int ndim = ins[i]->dims().size();
TensorDistAttr dist_attr_dst =
CopyTensorDistAttrForOutput(ins[i]->dist_attr());
std::vector<int64_t> dst_dims_mapping =
GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim);
dist_attr_dst.set_dims_mapping(dst_dims_mapping);
dst_input_dist_attrs.emplace_back(dist_attr_dst);
}

VLOG(4) << "DefaultDataParallelSpmd InferForward:";
for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(phi::vectorize(ins[i]->dims())) << "] "
<< "src_dims_mapping: ["
<< str_join(ins[i]->dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: ["
<< str_join(dst_input_dist_attrs[i].dims_mapping()) << "]";
}

for (int64_t i = 0; i < noutputs; i++) {
VLOG(4) << "Output" << std::to_string(i) << " shape: ["
<< str_join(phi::vectorize(outs[i]->dims())) << "] "
<< "dst_dims_mapping: ["
<< str_join(output_dist_attrs[i].dims_mapping()) << "]";
}

return {dst_input_dist_attrs, output_dist_attrs};
}

SpmdInfo DefaultDataParallelSpmdInferBackward(
const std::vector<const DistMetaTensor*>& ins,
const std::vector<const DistMetaTensor*>& outs) {
// Step1: Build Einsum Notation for output tensors' batch axes
int64_t ninputs = ins.size();
int64_t noutputs = outs.size();
std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
std::string batch_axis = "b";

for (int64_t i = 0; i < noutputs; ++i) {
axes_sharding_info.push_back(
{batch_axis, {outs[i]->dist_attr().dims_mapping()[0]}});
}

// Step2: Sharding Merge
std::unordered_map<std::string, int64_t> axis_to_dim_map =
ShardingMergeForTensors(axes_sharding_info);
int64_t batch_axis_dim = axis_to_dim_map[batch_axis];

// Step3: Infer Output's Batch Axis Dims Mapping.
std::vector<TensorDistAttr> output_dist_attrs;
for (int64_t i = 0; i < noutputs; i++) {
int ndim = outs[i]->dims().size();
TensorDistAttr dist_attr_dst =
CopyTensorDistAttrForOutput(outs[i]->dist_attr());
std::vector<int64_t> dst_dims_mapping =
GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim);
dist_attr_dst.set_dims_mapping(dst_dims_mapping);
output_dist_attrs.emplace_back(dist_attr_dst);
}

// Step4: Update Inputs' Batch Axis Dims Mapping with the merged batch axis.
std::vector<TensorDistAttr> dst_input_dist_attrs;
for (int64_t i = 0; i < ninputs; i++) {
int ndim = ins[i]->dims().size();
TensorDistAttr dist_attr_dst =
CopyTensorDistAttrForOutput(ins[i]->dist_attr());
std::vector<int64_t> dst_dims_mapping =
GetDefaultDataParallelDimsmapping(batch_axis_dim, ndim);
dist_attr_dst.set_dims_mapping(dst_dims_mapping);
dst_input_dist_attrs.emplace_back(dist_attr_dst);
}

VLOG(4) << "DefaultDataParallelSpmd InferBackward:";
for (int64_t i = 0; i < noutputs; i++) {
VLOG(4) << "Output" << std::to_string(i) << " shape: ["
<< str_join(phi::vectorize(outs[i]->dims())) << "] "
<< "src_dims_mapping: ["
<< str_join(outs[i]->dist_attr().dims_mapping()) << "] "
<< "dst_dims_mapping: ["
<< str_join(output_dist_attrs[i].dims_mapping()) << "]";
}

for (int64_t i = 0; i < ninputs; i++) {
VLOG(4) << "Input" << std::to_string(i) << " shape: ["
<< str_join(phi::vectorize(ins[i]->dims())) << "] "
<< "dst_dims_mapping: ["
<< str_join(dst_input_dist_attrs[i].dims_mapping()) << "]";
}

return {dst_input_dist_attrs, output_dist_attrs};
}

} // namespace distributed
} // namespace phi
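As a concrete illustration of the mapping GetDefaultDataParallelDimsmapping produces above: only the batch axis (tensor axis 0) is sharded, on the merged mesh dimension, and every other axis is replicated. A standalone sketch (illustrative, not Paddle code):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the helper above: shard only the batch axis (axis 0) on the merged
// mesh dimension; -1 marks a replicated axis.
std::vector<int64_t> DefaultDataParallelDimsMapping(int64_t batch_axis_dim,
                                                    int ndim) {
  std::vector<int64_t> dims_mapping(ndim, -1);
  dims_mapping[0] = batch_axis_dim;
  return dims_mapping;
}

int main() {
  // A rank-3 tensor whose batch axis is sharded on mesh dim 0 comes out as
  // dims_mapping == [0, -1, -1].
  std::vector<int64_t> expected = {0, -1, -1};
  assert(DefaultDataParallelDimsMapping(0, 3) == expected);
  return 0;
}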