[Semi-Auto] Adapt reshape spmd rule to phi (PaddlePaddle#57573)
* adapt reshape spmd rule to phi

* fix the bug when op attribute is vector<int64_t> type

* add two more unit test cases
pkuzyc authored Sep 25, 2023
1 parent 738674c commit b4d25f4
Showing 9 changed files with 235 additions and 187 deletions.

This file was deleted.

4 changes: 0 additions & 4 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -18,7 +18,6 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
@@ -50,9 +49,6 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
// transpose rule
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);

// reshape rule
REGISTER_SPMD_RULE(reshape, ReshapeSPMDRule);

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
23 changes: 22 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -86,7 +86,28 @@ std::vector<int> InferSpmdContext::AttrAt(size_t idx) const {
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferSpmd Context, the input attr type is "
"`%s`, but the expected attribute type is `bool`.",
"`%s`, but the expected attribute type is `std::vector<int>`.",
attrs_.at(idx).type().name()));
}
}

template <>
std::vector<int64_t> InferSpmdContext::AttrAt(size_t idx) const {
try {
auto attr = attrs_.at(idx);
if (attr.type() == typeid(std::vector<bool>)) {
std::vector<bool> val = PADDLE_GET_CONST(std::vector<bool>, attr);
return std::vector<int64_t>(val.begin(), val.end());
} else if (attr.type() == typeid(std::vector<int>)) {
std::vector<int> val = PADDLE_GET_CONST(std::vector<int>, attr);
return std::vector<int64_t>(val.begin(), val.end());
} else {
return PADDLE_GET_CONST(std::vector<int64_t>, attr);
}
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferSpmd Context, the input attr type is "
"`%s`, but the expected attribute type is `std::vector<int64_t>`.",
attrs_.at(idx).type().name()));
}
}
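
The specialization above is the fix for the second bullet in the commit message: an op attribute declared as `std::vector<int64_t>` (such as reshape's `shape`) may actually be stored in the attribute variant as `std::vector<int>` or `std::vector<bool>`, so the getter widens element-wise instead of letting the variant cast throw. A self-contained sketch of that widening (plain C++, not Paddle API):

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // vector<int> -> vector<int64_t>: the range constructor widens each
  // element, mirroring the std::vector<int> branch of AttrAt above.
  std::vector<int> shape32 = {6, 12, -1};
  std::vector<int64_t> shape64(shape32.begin(), shape32.end());
  assert(shape64.size() == 3 && shape64[2] == -1);

  // vector<bool> is bit-packed and its iterators yield proxies; range
  // construction still produces clean 0/1 values, as in the
  // std::vector<bool> branch above.
  std::vector<bool> flags = {true, false};
  std::vector<int64_t> flags64(flags.begin(), flags.end());
  assert(flags64[0] == 1 && flags64[1] == 0);
  return 0;
}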
paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include <assert.h>
#include <cstdio>
#include <numeric>
#include <set>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

static std::vector<DimTrans*> all_dim_trans;

@@ -289,10 +288,11 @@ void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
}

std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
const std::vector<int64_t>& input_shape = input_spec.shape();
const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
std::vector<int64_t> input_shape = phi::vectorize(input.dims());
const std::vector<int64_t>& input_dims_mapping =
input.dist_attr().dims_mapping();
const ProcessMesh& mesh = input.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();

std::set<int64_t> sharded_input_dims;
@@ -354,6 +354,5 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
return {new_input_dims_mapping, out_dims_mapping};
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
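
Note that the two rows returned by InferFromDimTrans are consumed positionally by the rules below: element 0 is the dims mapping the input should be resharded to (when it differs from the input's current mapping), and element 1 is the dims mapping inferred for the output.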
paddle/phi/infermeta/spmd_rules/dim_trans.h
@@ -17,11 +17,11 @@ limitations under the License. */
#include <iostream>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

// This is a base class to describe how each dimension in output tensor
// is transformed from input tensor's axes. The transformation includes
@@ -153,8 +153,7 @@ DimTrans* make_split(DimTrans* dim,
// leftmost output split axis can be sharded when its shape is divisible
// by the mesh dimension.
std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans);
const DistMetaTensor& input_spec, const std::vector<DimTrans*>& dim_trans);

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
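
To make this contract concrete (a worked example consistent with the comment above, not taken from the diff): reshaping [6, 12] to [6, 4, 3] maps output axis 0 to input axis 0, and output axes 1 and 2 to a split of input axis 1 into (4, 3). If input axis 1 is sharded on a mesh dimension of size 2, the leftmost split axis (size 4, divisible by 2) can keep the shard, so an input dims mapping of [-1, 0] yields an output dims mapping of [-1, 0, -1].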
paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -12,14 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
#include <numeric>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

@@ -71,9 +76,9 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
std::vector<int64_t> inferred_tgt_shape =
InferTargetShape(tgt_shape, total_elem_num_src);

int64_t src_idx = 0, tgt_idx = 0;
int64_t s, t;
int64_t src_len, tgt_len;
int src_idx = 0, tgt_idx = 0;
int s, t;
int src_len, tgt_len;
src_len = static_cast<int64_t>(src_shape.size());
tgt_len = static_cast<int64_t>(inferred_tgt_shape.size());
while (src_idx < src_len || tgt_idx < tgt_len) {
@@ -135,29 +140,27 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
return ret;
}
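
As an illustration of how MakeReshapeDimTrans and InferTargetShape work together, here is a hypothetical probe (a sketch that assumes only the signatures visible in this diff; to_string() and CleanUp() are declared in dim_trans.h):

std::vector<int64_t> src = {6, 12};
std::vector<int64_t> tgt = {6, -1, 3};  // -1 is inferred: 72 / (6 * 3) = 4
std::vector<DimTrans*> trans = MakeReshapeDimTrans(src, tgt);
// Expected result: trans[0] reuses input dim 0, while trans[1] and
// trans[2] are the two parts of a split of input dim 1 into (4, 3).
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
  VLOG(4) << "Out axis " << i << ": " << trans[i]->to_string();
}
CleanUp();  // releases the DimTrans nodes tracked in all_dim_trans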

//
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = static_cast<int64_t>(input_specs.size());
SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
auto src_shape = phi::vectorize(x.dims());
int x_ndim = src_shape.size();
auto x_dist_attr_src = x.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reshape must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "reshape");

// step1: build the transformation from
// original shape to target shape
std::vector<int64_t> src_shape = input_specs[0].shape();
std::vector<int64_t> tgt_shape =
ExtractAttr<std::vector<int64_t>>("shape", attrs);
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
"dims_mapping size [%d] are not matched.",
x_ndim,
x_dims_mapping.size()));

// Step1: Build the transformation from
// the original shape to the target shape

// handle the '0' values in the target shape: a '0' means the
// corresponding dimension keeps its value from the source shape
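// (e.g. src_shape = {6, 12} with shape = {0, 3, 4} resolves to {6, 3, 4})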
std::vector<int64_t> tgt_shape(shape);
for (int64_t i = 0, n = static_cast<int64_t>(tgt_shape.size()); i < n; i++) {
if (tgt_shape[i] == 0) {
tgt_shape[i] = src_shape[i];
@@ -166,96 +169,88 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(

std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);

// step2: infer the dims mapping of input (if reshard is
// Step2: Infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(input_specs[0], trans);
InferFromDimTrans(x, trans);

// step3: update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr new_input_dist_attr(input_specs[0].dist_attr());
new_input_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr out_dist_attr(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
<< "] output_shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape)
<< "] Out shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tOutput axis " << i << ": " << t->to_string();
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[0])
<< "] output_dims_mapping: [" << str_join(dims_mapping_vec[1])
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";

CleanUp();

return {{new_input_dist_attr}, {output_dist_attr}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}
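
For intuition, a worked example under assumed inputs (not from the diff): with x of shape [6, 12] and dims_mapping [-1, 0] on a mesh of shape [2], calling the rule with shape = {0, 4, 3} resolves the target to [6, 4, 3]; dim 1 is split into (4, 3), and since the leftmost split axis (4) is divisible by the mesh dimension (2), the inferred out dims_mapping is [-1, 0, -1] while x keeps [-1, 0].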

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
auto x_shape = phi::vectorize(x.dims());
auto out_shape = phi::vectorize(out.dims());
int out_ndim = out_shape.size();
auto out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reshape must "
"be equal to 1, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
1,
phi::errors::InvalidArgument("The size of OutputSpec in reshape must "
"be equal to 1, but got [%d].",
noutputs));
VerifySpecs(output_specs, "reshape");

// step1: build the transformation from the output shape
// to original shape. Inferbackward infers the dims mapping
out_ndim,
out_dims_mapping.size(),
phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping.size()));

// Step1: Build the transformation from the output shape
// to the original shape. This function infers the dims mapping
// from output to input, so we first get the transformation
// from output to input and then infer the dims mapping
// with the map from output axes to input axes.
// Shapes in Inferbackward don't contain -1 or 0, so they will
// not be modified and we can use ref here.
const std::vector<int64_t>& output_shape = output_specs[0].shape();
const std::vector<int64_t>& input_shape = input_specs[0].shape();
// Shapes in InferSpmdReverse don't contain -1 or 0, so they will
// not be modified and we can directly use them.
std::vector<DimTrans*> trans = MakeReshapeDimTrans(out_shape, x_shape);

std::vector<DimTrans*> trans = MakeReshapeDimTrans(output_shape, input_shape);

// step2: infer the dims mapping of input with
// Step2: Infer the dims mapping of input with
// output's dims_mapping and the transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(output_specs[0], trans);
InferFromDimTrans(out, trans);

// step3: update the dist attributes of input
// Step3: Update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr new_output_dist_attr(output_specs[0].dist_attr());
new_output_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
input_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr x_dist_attr(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape Inferbackward: output_shape: [" << str_join(output_shape)
<< "] input_shape: [" << str_join(input_shape) << "]";
VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "] X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Transformation from output to input:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tInput axis " << i << ": " << t->to_string();
VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "] output_dims_mapping: [" << str_join(dims_mapping_vec[0])
<< "]\n\n";
VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
<< "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

CleanUp();

return {{input_dist_attr}, {new_output_dist_attr}};
return {{x_dist_attr}, {out_dist_attr_dst}};
}
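
Mirroring the forward example (again an assumed trace): with out of shape [6, 4, 3] and x of shape [6, 12], MakeReshapeDimTrans(out_shape, x_shape) maps x dim 0 to out dim 0 and x dim 1 to a flatten of out dims 1 and 2; the out dims_mapping then flows back to x through that map.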

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi