[Semi-Auto] Adapt reshape spmd rule to phi #57573

Merged: 3 commits, Sep 25, 2023
This file was deleted.

4 changes: 0 additions & 4 deletions paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h
@@ -18,7 +18,6 @@
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/split_spmd_rule.h"
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/transpose_spmd_rule.h"
@@ -50,9 +49,6 @@ REGISTER_SPMD_RULE(split_with_num, SplitSPMDRule);
// transpose rule
REGISTER_SPMD_RULE(transpose, TransposeSPMDRule);

// reshape rule
REGISTER_SPMD_RULE(reshape, ReshapeSPMDRule);

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
23 changes: 22 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc
@@ -86,7 +86,28 @@ std::vector<int> InferSpmdContext::AttrAt(size_t idx) const {
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferSpmd Context, the input attr type is "
"`%s`, but the expected attribute type is `bool`.",
"`%s`, but the expected attribute type is `std::vector<int>`.",
attrs_.at(idx).type().name()));
}
}

template <>
std::vector<int64_t> InferSpmdContext::AttrAt(size_t idx) const {
try {
auto attr = attrs_.at(idx);
if (attr.type() == typeid(std::vector<bool>)) {
std::vector<bool> val = PADDLE_GET_CONST(std::vector<bool>, attr);
return std::vector<int64_t>(val.begin(), val.end());
} else if (attr.type() == typeid(std::vector<int>)) {
std::vector<int> val = PADDLE_GET_CONST(std::vector<int>, attr);
return std::vector<int64_t>(val.begin(), val.end());
} else {
return PADDLE_GET_CONST(std::vector<int64_t>, attr);
}
} catch (paddle::bad_variant_access const& e) {
PADDLE_THROW(phi::errors::InvalidArgument(
"Attribute cast error in InferSpmd Context, the input attr type is "
"`%s`, but the expected attribute type is `std::vector<int64_t>`.",
attrs_.at(idx).type().name()));
}
}
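For readers of this hunk: the new specialization falls back from std::vector<bool> and std::vector<int> to std::vector<int64_t>, widening the element type along the way. A minimal standalone sketch of that widening, using std::variant in place of Paddle's attribute type (the names here are illustrative only, not part of the PR):

#include <cstdint>
#include <variant>
#include <vector>

// Illustrative stand-in for the attribute storage; the real code uses
// Paddle's Attribute variant and PADDLE_GET_CONST.
using Attr = std::variant<std::vector<bool>, std::vector<int>, std::vector<int64_t>>;

// Widen whichever integer-like vector the attribute holds to int64_t,
// mirroring the fallback chain in the new AttrAt specialization.
std::vector<int64_t> ToInt64Vector(const Attr& attr) {
  if (const auto* b = std::get_if<std::vector<bool>>(&attr)) {
    return std::vector<int64_t>(b->begin(), b->end());
  }
  if (const auto* i = std::get_if<std::vector<int>>(&attr)) {
    return std::vector<int64_t>(i->begin(), i->end());
  }
  return std::get<std::vector<int64_t>>(attr);  // already the target type
}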
paddle/phi/infermeta/spmd_rules/dim_trans.cc
@@ -12,17 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include <assert.h>
#include <cstdio>
#include <numeric>
#include <set>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/enforce.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

static std::vector<DimTrans*> all_dim_trans;

@@ -289,10 +288,11 @@ void GetUsedInputDim(DimTrans* dim_trans, std::set<int64_t>* seen_dims) {
}

std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans) {
const std::vector<int64_t>& input_shape = input_spec.shape();
const std::vector<int64_t>& input_dims_mapping = input_spec.dims_mapping();
const ProcessMesh& mesh = input_spec.dist_attr().process_mesh();
const DistMetaTensor& input, const std::vector<DimTrans*>& dim_trans) {
std::vector<int64_t> input_shape = phi::vectorize(input.dims());
const std::vector<int64_t>& input_dims_mapping =
input.dist_attr().dims_mapping();
const ProcessMesh& mesh = input.dist_attr().process_mesh();
const std::vector<int64_t>& mesh_shape = mesh.shape();

std::set<int64_t> sharded_input_dims;
@@ -354,6 +354,5 @@ std::vector<std::vector<int64_t>> InferFromDimTrans(
return {new_input_dims_mapping, out_dims_mapping};
}

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
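As context for the new signature: InferFromDimTrans returns two dims mappings, the (possibly resharded) input mapping and the inferred output mapping. A hedged illustration of the expected result for reshape [6, 12, 24] -> [72, 24] when x is sharded on its first axis and that mesh dimension divides 6 (values shown for illustration only, not taken from the PR):

#include <cstdint>
#include <utility>
#include <vector>

// Hedged illustration, not the PR's code: for reshape [6, 12, 24] -> [72, 24]
// with input dims_mapping {0, -1, -1}, the flattened output axis is expected
// to inherit the sharding of the leftmost flattened input axis.
std::pair<std::vector<int64_t>, std::vector<int64_t>> ExpectedReshapeMappings() {
  std::vector<int64_t> new_input_dims_mapping = {0, -1, -1};  // input unchanged
  std::vector<int64_t> out_dims_mapping = {0, -1};            // the 72 axis keeps mesh dim 0
  return {new_input_dims_mapping, out_dims_mapping};
}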
paddle/phi/infermeta/spmd_rules/dim_trans.h
@@ -17,11 +17,11 @@ limitations under the License. */
#include <iostream>
#include <vector>

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

// This is a base class to describe how each dimension in output tensor
// is transformed from input tensor's axes. The transformation includes
@@ -153,8 +153,7 @@ DimTrans* make_split(DimTrans* dim,
// leftmost output split axis can be sharded when its shape can be divisible
// by the mesh dimension.
std::vector<std::vector<int64_t>> InferFromDimTrans(
const DistTensorSpec& input_spec, const std::vector<DimTrans*>& dim_trans);
const DistMetaTensor& input_spec, const std::vector<DimTrans*>& dim_trans);

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi
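The divisibility constraint mentioned in the comment above (only the leftmost output axis produced by a split may stay sharded, and only when its length divides evenly across the mesh dimension) can be stated as a small standalone predicate; this sketch is illustrative and not part of the diff:

#include <cstdint>

// Returns true when the leftmost axis produced by splitting an input axis can
// keep a sharding on a mesh dimension of the given size (hedged paraphrase of
// the rule described in the header comment above).
bool LeftmostSplitAxisCanStaySharded(int64_t leftmost_out_dim, int64_t mesh_dim_size) {
  return mesh_dim_size > 0 && leftmost_out_dim % mesh_dim_size == 0;
}
// e.g. splitting 72 into {8, 9}: the axis of length 8 can stay sharded on a
// mesh dimension of size 4, while the axis of length 9 would be replicated.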
paddle/phi/infermeta/spmd_rules/reshape.cc
@@ -12,14 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h"
#include "paddle/phi/infermeta/spmd_rules/reshape.h"
#include <numeric>
#include "paddle/fluid/distributed/auto_parallel/spmd_rules/dim_trans.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/dim_trans.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace paddle {
namespace phi {
namespace distributed {
namespace auto_parallel {

using phi::distributed::auto_parallel::str_join;

@@ -71,9 +76,9 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
std::vector<int64_t> inferred_tgt_shape =
InferTargetShape(tgt_shape, total_elem_num_src);

int64_t src_idx = 0, tgt_idx = 0;
int64_t s, t;
int64_t src_len, tgt_len;
int src_idx = 0, tgt_idx = 0;
int s, t;
int src_len, tgt_len;
src_len = static_cast<int64_t>(src_shape.size());
tgt_len = static_cast<int64_t>(inferred_tgt_shape.size());
while (src_idx < src_len || tgt_idx < tgt_len) {
@@ -135,29 +140,27 @@ std::vector<DimTrans*> MakeReshapeDimTrans(
return ret;
}

//
std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(
const std::vector<DistTensorSpec>& input_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = static_cast<int64_t>(input_specs.size());
SpmdInfo ReshapeInferSpmd(const DistMetaTensor& x,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
auto src_shape = phi::vectorize(x.dims());
int x_ndim = src_shape.size();
auto x_dist_attr_src = x.dist_attr();
std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reshape must "
"be equal to 1, but got [%d].",
ninputs));
VerifySpecs(input_specs, "reshape");

// step1: build the transformation from
// original shape to target shape
std::vector<int64_t> src_shape = input_specs[0].shape();
std::vector<int64_t> tgt_shape =
ExtractAttr<std::vector<int64_t>>("shape", attrs);
x_ndim,
x_dims_mapping.size(),
phi::errors::InvalidArgument("The Tensor X's rank [%d] and X's "
"dims_mapping size [%d] are not matched.",
x_ndim,
x_dims_mapping.size()));

// Step1: Build the transformation from
// the original shape to the target shape

// handle the '0' values in target shape, '0' indicates
// that the target shape is equal to the source shape
std::vector<int64_t> tgt_shape(shape);
for (int64_t i = 0, n = static_cast<int64_t>(tgt_shape.size()); i < n; i++) {
if (tgt_shape[i] == 0) {
tgt_shape[i] = src_shape[i];
@@ -166,96 +169,88 @@ paddle::distributed::auto_parallel::ReshapeSPMDRule::InferForward(

std::vector<DimTrans*> trans = MakeReshapeDimTrans(src_shape, tgt_shape);

// step2: infer the dims mapping of input (if reshard is
// Step2: Infer the dims mapping of input (if reshard is
// needed) and output from the dimension transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(input_specs[0], trans);
InferFromDimTrans(x, trans);

// step3: update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr new_input_dist_attr(input_specs[0].dist_attr());
new_input_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr output_dist_attr(input_specs[0].dist_attr());
output_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
// Step3: Update the dist attributes of input
// and output with the inferred dims mapping.
TensorDistAttr x_dist_attr_dst(x_dist_attr_src);
x_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr out_dist_attr(x_dist_attr_src);
out_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape: input_shape: [" << str_join(src_shape)
<< "] output_shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "ReshapeInferSpmd: X shape: [" << str_join(src_shape)
<< "] Out shape: [" << str_join(tgt_shape) << "]";
VLOG(4) << "Transformation from input to output:";
for (int64_t i = 0, n = static_cast<int64_t>(trans.size()); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tOutput axis " << i << ": " << t->to_string();
VLOG(4) << "\tOut axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[0])
<< "] output_dims_mapping: [" << str_join(dims_mapping_vec[1])
VLOG(4) << "X dims_mapping_src: [" << str_join(x_dims_mapping)
<< "] dims_mapping_dst: [" << str_join(dims_mapping_vec[0])
<< "]\n Out dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "]\n\n";

CleanUp();

return {{new_input_dist_attr}, {output_dist_attr}};
return {{x_dist_attr_dst}, {out_dist_attr}};
}
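Step1 above substitutes a source dimension for each 0 in the target shape, and InferTargetShape (not shown in this hunk) fills in a single -1 from the remaining element count. A standalone sketch of that resolution, under the assumption of at most one -1 and an evenly dividing product (the helper name is hypothetical, not part of the PR):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Resolve a reshape target shape that may contain 0 ("same as source dim")
// and at most one -1 ("infer from the remaining element count").
std::vector<int64_t> ResolveTargetShape(const std::vector<int64_t>& src,
                                        std::vector<int64_t> tgt) {
  int64_t total = std::accumulate(src.begin(), src.end(), int64_t{1},
                                  std::multiplies<int64_t>());
  int64_t known = 1;
  int64_t infer_idx = -1;
  for (std::size_t i = 0; i < tgt.size(); ++i) {
    if (tgt[i] == 0) tgt[i] = src[i];  // 0 copies the source dimension
    if (tgt[i] == -1) {
      infer_idx = static_cast<int64_t>(i);
      continue;  // fill this one in afterwards
    }
    known *= tgt[i];
  }
  if (infer_idx >= 0) {
    assert(known != 0 && total % known == 0);
    tgt[static_cast<std::size_t>(infer_idx)] = total / known;
  }
  return tgt;
}
// e.g. ResolveTargetShape({6, 12, 24}, {0, -1, 6}) yields {6, 48, 6}.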

std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
paddle::distributed::auto_parallel::ReshapeSPMDRule::InferBackward(
const std::vector<DistTensorSpec>& input_specs,
const std::vector<DistTensorSpec>& output_specs,
const paddle::framework::AttributeMap& attrs) {
// step0: Verify Input Args Based on Reshape Logic
int64_t ninputs = input_specs.size();
int64_t noutputs = output_specs.size();
SpmdInfo ReshapeInferSpmdReverse(const DistMetaTensor& x,
const DistMetaTensor& out,
const std::vector<int64_t>& shape) {
// Step0: Verify input args based on reshape logic
auto x_shape = phi::vectorize(x.dims());
auto out_shape = phi::vectorize(out.dims());
int out_ndim = out_shape.size();
auto out_dist_attr_src = out.dist_attr();
std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
PADDLE_ENFORCE_EQ(
ninputs,
1,
phi::errors::InvalidArgument("The size of InputSpec in reshape must "
"be equal to 1, but got [%d].",
ninputs));
PADDLE_ENFORCE_EQ(
noutputs,
1,
phi::errors::InvalidArgument("The size of OutputSpec in reshape must "
"be equal to 1, but got [%d].",
noutputs));
VerifySpecs(output_specs, "reshape");

// step1: build the transformation from the output shape
// to original shape. Inferbackward infers the dims mapping
out_ndim,
out_dims_mapping.size(),
phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
"dims_mapping size [%d] are not matched.",
out_ndim,
out_dims_mapping.size()));

// Step1: Build the transformation from the output shape
// to original shape. This function infers the dims mapping
// from output to input, we first get the transformation
// from output to input so that we can infer the dims mapping
// with the map from output axes to input axes.
// Shapes in Inferbackward don't contain -1 or 0, so they will
// not be modified and we can use ref here.
const std::vector<int64_t>& output_shape = output_specs[0].shape();
const std::vector<int64_t>& input_shape = input_specs[0].shape();
// Shapes in InferSpmdReverse don't contain -1 or 0, so they will
// not be modified and we can directly use them.
std::vector<DimTrans*> trans = MakeReshapeDimTrans(out_shape, x_shape);

std::vector<DimTrans*> trans = MakeReshapeDimTrans(output_shape, input_shape);

// step2: infer the dims mapping of input with
// Step2: Infer the dims mapping of input with
// output's dims_mapping and the transformation.
std::vector<std::vector<int64_t>> dims_mapping_vec =
InferFromDimTrans(output_specs[0], trans);
InferFromDimTrans(out, trans);

// step3: update the dist attributes of input
// Step3: Update the dist attributes of input
// and output with the inferred dims mapping
TensorDistAttr new_output_dist_attr(output_specs[0].dist_attr());
new_output_dist_attr.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr input_dist_attr(input_specs[0].dist_attr());
input_dist_attr.set_dims_mapping(dims_mapping_vec[1]);
TensorDistAttr out_dist_attr_dst(out_dist_attr_src);
out_dist_attr_dst.set_dims_mapping(dims_mapping_vec[0]);
TensorDistAttr x_dist_attr(x.dist_attr());
x_dist_attr.set_dims_mapping(dims_mapping_vec[1]);

VLOG(4) << "Reshape Inferbackward: output_shape: [" << str_join(output_shape)
<< "] input_shape: [" << str_join(input_shape) << "]";
VLOG(4) << "ReshapeInferSpmdReverse: Out shape: [" << str_join(out_shape)
<< "] X shape: [" << str_join(x_shape) << "]";
VLOG(4) << "Transformation from output to input:";
for (int64_t i = 0, n = trans.size(); i < n; i++) {
DimTrans* t = trans[i];
VLOG(4) << "\tInput axis " << i << ": " << t->to_string();
VLOG(4) << "\tX axis[" << i << "]: " << t->to_string();
}
VLOG(4) << "input_dims_mapping: [" << str_join(dims_mapping_vec[1])
<< "] output_dims_mapping: [" << str_join(dims_mapping_vec[0])
<< "]\n\n";
VLOG(4) << "Out dims_mapping_src: [" << str_join(out_dims_mapping) << "] "
<< "dims_mapping_dst: [" << str_join(dims_mapping_vec[0]) << "]";
VLOG(4) << "X dims_mapping: [" << str_join(dims_mapping_vec[1]) << "]\n\n";

CleanUp();

return {{input_dist_attr}, {new_output_dist_attr}};
return {{x_dist_attr}, {out_dist_attr_dst}};
}
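A hedged worked example of the reverse rule: for a reshape of [6, 12, 24] into [72, 24] where the output's dims_mapping is [0, -1], the transformation is built from [72, 24] back to [6, 12, 24], so the input dims_mapping would be inferred as [0, -1, -1] (the leftmost axis produced by the split keeps the sharding), assuming the size of mesh dimension 0 divides 6.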

} // namespace auto_parallel
} // namespace distributed
} // namespace paddle
} // namespace phi