diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc deleted file mode 100644 index 81f25a8d6ed889..00000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.cc +++ /dev/null @@ -1,280 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h" - -#include "paddle/phi/core/distributed/auto_parallel/utils.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { -using phi::distributed::auto_parallel::str_join; -std::pair, std::vector> -LayerNormSPMDRule::InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: verify input args based on layer_norm logic - auto input_specs_size = input_specs.size(); - PADDLE_ENFORCE_EQ( - input_specs_size, - 3, - phi::errors::InvalidArgument( - "The size of InputSpec of layer_norm should be 3, but got [%d].", - input_specs_size)); - auto x_shape = input_specs[0].shape(); - auto scale_shape = input_specs[1].shape(); - auto bias_shape = input_specs[2].shape(); - int x_ndim = static_cast(x_shape.size()); - int scale_ndim = static_cast(scale_shape.size()); - int bias_ndim = static_cast(bias_shape.size()); - - PADDLE_ENFORCE_EQ( - scale_ndim, - 1, - phi::errors::InvalidArgument( - "The ndim of scale in layer_norm should be 1, but got [%d].", - scale_ndim)); - - PADDLE_ENFORCE_EQ( - bias_ndim, - 1, - phi::errors::InvalidArgument( - "The ndim of bias in layer_norm should be 1, but got [%d].", - bias_ndim)); - - auto x_dims_mapping = input_specs[0].dist_attr().dims_mapping(); - auto scale_dims_mapping = input_specs[1].dist_attr().dims_mapping(); - auto bias_dims_mapping = input_specs[2].dist_attr().dims_mapping(); - - auto x_dist_attr_src = input_specs[0].dist_attr(); - - std::vector input_dist_attrs; - input_dist_attrs.reserve(input_specs.size()); - - int begin_norm_axis = ExtractAttr("begin_norm_axis", attrs); - - VLOG(4) << "LayerNormSPMDRule InferForward Inputs: " - << "x shape: [" << str_join(x_shape) << "], x_dims_mapping: [" - << str_join(x_dims_mapping) << "]; scale shape: [" - << str_join(scale_shape) << "], scale_dims_mapping: [" - << str_join(scale_dims_mapping) << "]; bias shape: [" - << str_join(bias_shape) << "], bias_dims_mapping: [" - << str_join(bias_dims_mapping) << "]; begin_norm_axis: [" - << begin_norm_axis << "]; "; - - // step1: build Einsum Notation - // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij) - // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance, - // begin_norm_axis=2, z=ij, y=kl) - std::string x_axes = ""; - for (auto i = 0; i < x_ndim; ++i) { - x_axes += static_cast(static_cast('k') - begin_norm_axis + i); - } - - std::string scale_axes; - std::string bias_axes; - if (x_ndim - begin_norm_axis == 1) { - scale_axes = "k"; - bias_axes = "k"; - } else { 
- // z = x_axes.substr(begin_norm_axis, x_ndim - begin_norm_axis) - scale_axes = "y"; - bias_axes = "y"; - } - - std::string mean_axes; - std::string variance_axes; - if (begin_norm_axis > 1) { - mean_axes = "z"; - variance_axes = "z"; - } else { - mean_axes = "j"; - variance_axes = "j"; - } - - std::string out_axes = x_axes; - - VLOG(4) << "LayerNormSPMDRule build Einsum notation (x,scale,bias->out): [" - << x_axes << "," << scale_axes << "," << bias_axes << " --> " - << out_axes << "," << mean_axes << "," << variance_axes - << "](begin_norm_axis:" << begin_norm_axis - << ",y=" << x_axes.substr(begin_norm_axis, x_ndim - begin_norm_axis) - << ",z=" << x_axes.substr(0, begin_norm_axis) << ")."; - - // step2: Sharding Propogation - TensorDistAttr output_dist_attr_dst = - CopyTensorDistAttrForOutput(x_dist_attr_src); - TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src); - TensorDistAttr mean_dist_attr_dst = - CopyTensorDistAttrForOutput(x_dist_attr_src); - TensorDistAttr varience_dist_attr_dst = - CopyTensorDistAttrForOutput(x_dist_attr_src); - std::vector out_dims_mapping; - out_dims_mapping.reserve(out_axes.size()); - - int64_t mean_shard_dim = -1; - // As the mean and variance in outputs are `flattened` from - // x[0:begin_norm_axis], only the first axis can be sharded, - // the axes 1 to begin_norm_axis-1 are set to be replicated. - std::vector x_dims_mapping_dst(x_ndim, -1); - x_dims_mapping_dst[0] = x_dims_mapping[0]; - for (int i = 0; i < x_ndim; ++i) { - if (i < begin_norm_axis) { - out_dims_mapping.push_back(x_dims_mapping_dst[i]); - // if ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, - // begin_norm_axis=2, z=ij), and the dims_mapping of input is (0,1,-1), - // the mean and varience is sharded by dim 0 and 1, - // which is not supported currently. - mean_shard_dim = ShardingMergeForAxis( - mean_axes, mean_shard_dim, x_dims_mapping_dst[i]); - } else { - out_dims_mapping.push_back(-1); - } - } - output_dist_attr_dst.set_dims_mapping(out_dims_mapping); - mean_dist_attr_dst.set_dims_mapping({mean_shard_dim}); - varience_dist_attr_dst.set_dims_mapping({mean_shard_dim}); - - // step2.3: Merge and get Inputs' New Dims Mapping. - x_dist_attr_dst.set_dims_mapping(x_dims_mapping_dst); - input_dist_attrs.emplace_back(x_dist_attr_dst); - // TODO(zhiqiu): support shardding on scale and bias - // Now, apply replicating. - input_dist_attrs.emplace_back(ReplicatedOnMesh(input_specs[1].dist_attr())); - input_dist_attrs.emplace_back(ReplicatedOnMesh(input_specs[2].dist_attr())); - - // Step2.4. 
handle input and out tensor partial - // LayerNorm not support - - VLOG(4) << "LayerNormSPMDRule InferForward: " - << "X shape: [" << str_join(x_shape) << "], src_dims_mapping: [" - << str_join(x_dims_mapping) << "], dst_dims_mapping: [" - << str_join(x_dist_attr_dst.dims_mapping()) << "]; scale shape: [" - << str_join(scale_shape) << "], src_dims_mapping: [" - << str_join(scale_dims_mapping) << "], dst_dims_mapping: [" - << str_join(input_dist_attrs[1].dims_mapping()) << "]; bias shape: [" - << str_join(bias_shape) << "], src_dims_mapping: [" - << str_join(bias_dims_mapping) << "], dst_dims_mapping: [" - << str_join(input_dist_attrs[2].dims_mapping()) - << "]; out dims_mapping: [" << str_join(out_dims_mapping) - << "]; mean dims_mapping: [" << mean_shard_dim - << "]; varience dims_mapping: [" << mean_shard_dim - << "], partial_on_dims: []"; - - return {input_dist_attrs, - {output_dist_attr_dst, mean_dist_attr_dst, varience_dist_attr_dst}}; -} - -std::pair, std::vector> -LayerNormSPMDRule::InferBackward( - const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) { - // step0: verify input args based on layer_norm logic - int64_t ninputs = input_specs.size(); - int64_t noutputs = output_specs.size(); - PADDLE_ENFORCE_EQ( - ninputs, - 3, - phi::errors::InvalidArgument( - "The size of InputSpec of layer_norm should be 3, but got [%d].", - ninputs)); - PADDLE_ENFORCE_EQ( - noutputs, - 3, - phi::errors::InvalidArgument( - "The size of InputSpec of layer_norm should be 3, but got [%d].", - noutputs)); - VerifySpecs(output_specs, "layer_norm_backward"); - - // step1: build Einsum Notation - // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij) - // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance, - // begin_norm_axis=2, z=ij, y=kl) - int begin_norm_axis = ExtractAttr("begin_norm_axis", attrs); - std::string alphabet = "ijklmnopqrstuvwxyz"; - int x_ndim = input_specs[0].shape().size(); - std::string x_axes = alphabet.substr(0, x_ndim); - // the axes after norm_axis should be replicated, - // so set their notation to '1'. - for (int i = 1; i < x_ndim; i++) { - x_axes[i] = '1'; - } - std::string out_axes = x_axes; - std::string mean_axes(1, '1'), varience_axes(1, '1'); - if (begin_norm_axis > 0) { - mean_axes[0] = out_axes[0]; - varience_axes[0] = out_axes[0]; - } - std::vector output_axes_vec; - output_axes_vec.emplace_back(out_axes); - output_axes_vec.emplace_back(mean_axes); - output_axes_vec.emplace_back(varience_axes); - - // step2: Sharding Propogation - // For the axes after norm_axis in both input and output tensors, - // set their dims mappings to -1. For the other axes, set input - // tensor's dims mapping the same as output tensor's dims mapping. - // step2.1 merge dims mappings of output, mean, variance. 
- std::vector>> axes_sharding_info; - axes_sharding_info = GetAxesDimsMappingPair(output_axes_vec, output_specs); - std::unordered_map axis_to_dim_map = - ShardingMergeForTensors(axes_sharding_info); - - // step2.2 infer input dims mapping - std::vector input_dims_mapping = - GetDimsMappingForAxes(x_axes, axis_to_dim_map); - std::vector input_dist_attrs; - for (int64_t i = 0; i < ninputs; i++) { - input_dist_attrs.emplace_back(input_specs[i].dist_attr()); - } - input_dist_attrs[0].set_dims_mapping(input_dims_mapping); - // set bias and scale to be replicated - input_dist_attrs[1].set_dims_mapping({-1}); - input_dist_attrs[2].set_dims_mapping({-1}); - - // step2.3 update output dims mappings with merged one - std::vector output_dist_attrs; - for (int64_t i = 0; i < noutputs; i++) { - output_dist_attrs.emplace_back(output_specs[i].dist_attr()); - output_dist_attrs[i].set_dims_mapping( - GetDimsMappingForAxes(output_axes_vec[i], axis_to_dim_map)); - } - - VLOG(4) << "LayerNormSPMDRule InferBackward:"; - VLOG(4) << "begin_norm_axis: " << begin_norm_axis; - for (int64_t i = 0; i < noutputs; i++) { - VLOG(4) << "Output" << std::to_string(i) << " shape: [" - << str_join(output_specs[i].shape()) << "] " - << "Einsum Notation: " << output_axes_vec[i] - << " src_dims_mapping: [" - << str_join(output_specs[i].dims_mapping()) << "] " - << "dst_dims_mapping: [" - << str_join(output_dist_attrs[i].dims_mapping()) << "]"; - } - - for (int64_t i = 0; i < ninputs; i++) { - VLOG(4) << "Input" << std::to_string(i) << " shape: [" - << str_join(input_specs[i].shape()) << "] " - << "Einsum Notation: " << std::string(i == 0 ? x_axes : "1") - << " dims_mapping: [" - << str_join(input_dist_attrs[i].dims_mapping()) << "]"; - } - VLOG(4) << std::endl; - - return {input_dist_attrs, output_dist_attrs}; -} - -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h deleted file mode 100644 index da40f3da5653f5..00000000000000 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h +++ /dev/null @@ -1,41 +0,0 @@ -/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include -#include - -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" - -namespace paddle { -namespace distributed { -namespace auto_parallel { - -class LayerNormSPMDRule : public SPMDRuleBase { - public: - std::pair, std::vector> - InferForward(const std::vector& input_specs, - const paddle::framework::AttributeMap& attrs) override; - - std::pair, std::vector> - InferBackward(const std::vector& input_specs, - const std::vector& output_specs, - const paddle::framework::AttributeMap& attrs) override; -}; -} // namespace auto_parallel -} // namespace distributed -} // namespace paddle diff --git a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h index 54ae4325b8a15a..71f939ffd37850 100644 --- a/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h +++ b/paddle/fluid/distributed/auto_parallel/spmd_rules/rules.h @@ -17,7 +17,6 @@ #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/cross_entropy_with_softmax_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/embedding_spmd_rule.h" -#include "paddle/fluid/distributed/auto_parallel/spmd_rules/layer_norm_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/replicated_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/reshape_spmd_rule.h" #include "paddle/fluid/distributed/auto_parallel/spmd_rules/softmax_spmd_rule.h" @@ -29,9 +28,6 @@ namespace paddle { namespace distributed { namespace auto_parallel { -// layer_norm rule -REGISTER_SPMD_RULE(layer_norm, LayerNormSPMDRule); - // replicated rule REGISTER_SPMD_RULE(replicated, ReplicatedSPMDRule); diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc index 6e0c0f696fef4a..24030b5d0ffa80 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.cc @@ -53,6 +53,9 @@ AttrType InferSpmdContext::AttrAt(size_t idx) const { } } +template float InferSpmdContext::AttrAt(size_t idx) const; +template int InferSpmdContext::AttrAt(size_t idx) const; + template <> bool InferSpmdContext::AttrAt(size_t idx) const { try { diff --git a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h index 23b147a4bb3d7e..499c2340983a77 100644 --- a/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h +++ b/paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h @@ -153,6 +153,8 @@ struct InferSpmdFnImpl { // TODO(chenweihang): support other attr type later as needed PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(bool); + PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(int); + PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_ATTRIBUTE(float); PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF(std::vector); PD_SPECIALIZE_InferSpmdFnCallHelper_FOR_CONST_ATTRIBUTE_REF( std::vector); diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.cc b/paddle/phi/infermeta/spmd_rules/layer_norm.cc new file mode 100644 index 00000000000000..6befef19cfef1b --- /dev/null +++ b/paddle/phi/infermeta/spmd_rules/layer_norm.cc @@ -0,0 +1,282 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/phi/infermeta/spmd_rules/layer_norm.h"
+
+#include "glog/logging.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
+#include "paddle/phi/core/distributed/auto_parallel/utils.h"
+#include "paddle/phi/infermeta/spmd_rules/utils.h"
+
+namespace phi {
+namespace distributed {
+
+using phi::distributed::auto_parallel::str_join;
+
+SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x,
+                            const DistMetaTensor& scale,
+                            const DistMetaTensor& bias,
+                            float epsilon,
+                            int begin_norm_axis) {
+  // Step0: verify input args based on layer_norm logic
+  auto x_shape = phi::vectorize(x.dims());
+  auto scale_shape = phi::vectorize(scale.dims());
+  auto bias_shape = phi::vectorize(bias.dims());
+  int x_ndim = x_shape.size();
+  int scale_ndim = scale_shape.size();
+  int bias_ndim = bias_shape.size();
+  TensorDistAttr x_dist_attr_src = x.dist_attr();
+  std::vector<int64_t> x_dims_mapping = x_dist_attr_src.dims_mapping();
+  std::vector<int64_t> scale_dims_mapping = scale.dist_attr().dims_mapping();
+  std::vector<int64_t> bias_dims_mapping = bias.dist_attr().dims_mapping();
+
+  PADDLE_ENFORCE_EQ(
+      scale_ndim,
+      1,
+      phi::errors::InvalidArgument(
+          "The ndim of scale in layer_norm should be 1, but got [%d].",
+          scale_ndim));
+
+  PADDLE_ENFORCE_EQ(
+      bias_ndim,
+      1,
+      phi::errors::InvalidArgument(
+          "The ndim of bias in layer_norm should be 1, but got [%d].",
+          bias_ndim));
+
+  // Step1: Build Einsum Notation
+  // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij)
+  // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance,
+  // begin_norm_axis=2, z=ij, y=kl)
+  std::string alphabet = "ijklmnopqrstuvwxyz";
+  // get input notation
+  // Because the mean and variance are 'flattened' from
+  // x[0:begin_norm_axis], only the first axis of x can
+  // be sharded.
+  std::string x_axes(x_ndim, '1');
+  x_axes[0] = alphabet[0];
+  std::string scale_axes(1, x_axes[x_ndim - 1]);
+  std::string bias_axes(1, x_axes[x_ndim - 1]);
+
+  // get output notation
+  std::string out_axes = x_axes;
+  std::string mean_axes(1, '1'), variance_axes(1, '1');
+  if (begin_norm_axis > 0) {
+    mean_axes[0] = out_axes[0];
+    variance_axes[0] = out_axes[0];
+  }
+
+  // Step2: Sharding Propagation
+  // Step2.1: merge input sharding
+  // As the mean and variance in outputs are `flattened` from
+  // x[0:begin_norm_axis], only the first axis of x can be sharded;
+  // all the other axes are set to be replicated.
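+  // For example (cf. the cases in test/cpp/auto_parallel/spmd_rule_test.cc):
+  // with begin_norm_axis=2 and x dims_mapping [1, 0, -1], x is reset to
+  // [1, -1, -1]; out then keeps [1, -1, -1] and the flattened mean/variance
+  // inherit the first-axis shard, i.e. dims_mapping [1].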
+  std::fill(x_dims_mapping.begin() + 1, x_dims_mapping.end(), -1);
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors({{x_axes, x_dims_mapping}});
+
+  // Step2.2: infer output dims mapping
+  TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr mean_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr variance_dist_attr =
+      CopyTensorDistAttrForOutput(x_dist_attr_src);
+  out_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(out_axes, axis_to_dim_map));
+  mean_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
+  variance_dist_attr.set_dims_mapping(
+      GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
+
+  // Step2.3: update input dims mapping
+  TensorDistAttr x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
+  TensorDistAttr scale_dist_attr_dst =
+      CopyTensorDistAttrForOutput(scale.dist_attr());
+  TensorDistAttr bias_dist_attr_dst =
+      CopyTensorDistAttrForOutput(bias.dist_attr());
+  x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
+  // TODO(zhiqiu): support sharding on scale and bias
+  // For now, apply replication.
+  scale_dist_attr_dst.set_dims_mapping({-1});
+  bias_dist_attr_dst.set_dims_mapping({-1});
+
+  // Step2.4: handle input and out tensor partial
+  // LayerNorm does not support it for now.
+  VLOG(4) << "LayerNormInferSpmd:";
+  VLOG(4) << "begin_norm_axis: " << begin_norm_axis;
+  VLOG(4) << "Einsum Notation: " << x_axes << "," << scale_axes << ","
+          << bias_axes << "-->" << out_axes << "," << mean_axes << ","
+          << variance_axes;
+  VLOG(4) << "X"
+          << " shape: [" << str_join(x_shape) << "] "
+          << "src_dims_mapping: [" << str_join(x_dist_attr_src.dims_mapping())
+          << "] "
+          << "dst_dims_mapping: [" << str_join(x_dims_mapping) << "]";
+  VLOG(4) << "Scale"
+          << " shape: [" << str_join(scale_shape) << "] "
+          << "src_dims_mapping: [" << str_join(scale_dims_mapping) << "] "
+          << "dst_dims_mapping: ["
+          << str_join(scale_dist_attr_dst.dims_mapping()) << "]";
+  VLOG(4) << "Bias"
+          << " shape: [" << str_join(bias_shape) << "] "
+          << "src_dims_mapping: [" << str_join(bias_dims_mapping) << "] "
+          << "dst_dims_mapping: ["
+          << str_join(bias_dist_attr_dst.dims_mapping()) << "]";
+  VLOG(4) << "Out dims mapping: [" << str_join(out_dist_attr.dims_mapping())
+          << "]";
+  VLOG(4) << "Mean dims mapping: [" << str_join(mean_dist_attr.dims_mapping())
+          << "]";
+  VLOG(4) << "Variance dims mapping: ["
+          << str_join(variance_dist_attr.dims_mapping()) << "]";
+  VLOG(4) << std::endl;
+
+  return {{x_dist_attr_dst, scale_dist_attr_dst, bias_dist_attr_dst},
+          {out_dist_attr, mean_dist_attr, variance_dist_attr}};
+}
+
+SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x,
+                                   const DistMetaTensor& scale,
+                                   const DistMetaTensor& bias,
+                                   const DistMetaTensor& out,
+                                   const DistMetaTensor& mean,
+                                   const DistMetaTensor& variance,
+                                   float epsilon,
+                                   int begin_norm_axis) {
+  // Step0: Verify input args based on layer_norm logic
+  auto x_shape = phi::vectorize(x.dims());
+  auto out_shape = phi::vectorize(out.dims());
+  auto mean_shape = phi::vectorize(mean.dims());
+  auto variance_shape = phi::vectorize(variance.dims());
+  int x_ndim = x_shape.size();
+  int out_ndim = out_shape.size();
+  int mean_ndim = mean_shape.size();
+  int variance_ndim = variance_shape.size();
+  auto out_dist_attr_src = out.dist_attr();
+  auto mean_dist_attr_src = mean.dist_attr();
+  auto variance_dist_attr_src = variance.dist_attr();
+  std::vector<int64_t> out_dims_mapping = out_dist_attr_src.dims_mapping();
+  std::vector<int64_t> mean_dims_mapping =
+      mean_dist_attr_src.dims_mapping();
+  std::vector<int64_t> variance_dims_mapping =
+      variance_dist_attr_src.dims_mapping();
+  PADDLE_ENFORCE_EQ(
+      out_ndim,
+      out_dims_mapping.size(),
+      phi::errors::InvalidArgument("The Tensor Out's rank [%d] and Out's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   out_ndim,
+                                   out_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(
+      mean_ndim,
+      mean_dims_mapping.size(),
+      phi::errors::InvalidArgument("The Tensor Mean's rank [%d] and Mean's "
+                                   "dims_mapping size [%d] are not matched.",
+                                   mean_ndim,
+                                   mean_dims_mapping.size()));
+  PADDLE_ENFORCE_EQ(variance_ndim,
+                    variance_dims_mapping.size(),
+                    phi::errors::InvalidArgument(
+                        "The Tensor Variance's rank [%d] and Variance's "
+                        "dims_mapping size [%d] are not matched.",
+                        variance_ndim,
+                        variance_dims_mapping.size()));
+  // Step1: Build Einsum Notation
+  // ijk,k,k->ijk,z,z (x,scale,bias->out,mean,variance, begin_norm_axis=2, z=ij)
+  // ijkl,y(kl),y(kl)->ijkl,z(ij),z(ij) (x,scale,bias->out,mean,variance,
+  // begin_norm_axis=2, z=ij, y=kl)
+  std::string alphabet = "ijklmnopqrstuvwxyz";
+  // Only the first axis of x can be sharded (the axes from axis 1 onward are
+  // either normalized or folded into mean/variance), so set their notation
+  // to '1'.
+  std::string x_axes(x_ndim, '1');
+  x_axes[0] = alphabet[0];
+  std::string scale_axes(1, x_axes[x_ndim - 1]);
+  std::string bias_axes(1, x_axes[x_ndim - 1]);
+
+  std::string out_axes = x_axes;
+  std::string mean_axes(1, '1'), variance_axes(1, '1');
+  if (begin_norm_axis > 0) {
+    mean_axes[0] = out_axes[0];
+    variance_axes[0] = out_axes[0];
+  }
+
+  // Step2: Sharding Propagation
+  // For the axes after norm_axis in both input and output tensors,
+  // set their dims mappings to -1. For the other axes, set input
+  // tensor's dims mapping the same as output tensor's dims mapping.
+  // Step2.1: merge dims mappings of output, mean, variance.
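+  // For example: out dims_mapping [0, -1, -1], mean [0] and variance [0]
+  // merge to {i: 0}, so x is inferred as [0, -1, -1]; conflicting shardings
+  // such as mean [0] with out [-1, 1, -1] cannot be merged and raise an error.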
+  std::vector<std::pair<std::string, std::vector<int64_t>>> axes_sharding_info;
+  axes_sharding_info.emplace_back(std::make_pair(out_axes, out_dims_mapping));
+  axes_sharding_info.emplace_back(std::make_pair(mean_axes, mean_dims_mapping));
+  axes_sharding_info.emplace_back(
+      std::make_pair(variance_axes, variance_dims_mapping));
+  std::unordered_map<std::string, int64_t> axis_to_dim_map =
+      ShardingMergeForTensors(axes_sharding_info);
+
+  // Step2.2: infer input dims mapping
+  std::vector<int64_t> x_dims_mapping =
+      GetDimsMappingForAxes(x_axes, axis_to_dim_map);
+  std::vector<TensorDistAttr> input_dist_attrs;
+  input_dist_attrs.emplace_back(x.dist_attr());
+  input_dist_attrs.emplace_back(scale.dist_attr());
+  input_dist_attrs.emplace_back(bias.dist_attr());
+
+  input_dist_attrs[0].set_dims_mapping(x_dims_mapping);
+  // set bias and scale to be replicated
+  input_dist_attrs[1].set_dims_mapping({-1});
+  input_dist_attrs[2].set_dims_mapping({-1});
+
+  // Step2.3: update output dims mappings with the merged one
+  std::vector<TensorDistAttr> output_dist_attrs;
+  output_dist_attrs.emplace_back(out_dist_attr_src);
+  output_dist_attrs.emplace_back(mean_dist_attr_src);
+  output_dist_attrs.emplace_back(variance_dist_attr_src);
+  output_dist_attrs[0].set_dims_mapping(
+      GetDimsMappingForAxes(out_axes, axis_to_dim_map));
+  output_dist_attrs[1].set_dims_mapping(
+      GetDimsMappingForAxes(mean_axes, axis_to_dim_map));
+  output_dist_attrs[2].set_dims_mapping(
+      GetDimsMappingForAxes(variance_axes, axis_to_dim_map));
+
+  VLOG(4) << "LayerNormInferSpmdReverse:";
+  VLOG(4) << "begin_norm_axis: " << begin_norm_axis;
+  VLOG(4) << "Einsum Notation: " << x_axes << "," << scale_axes << ","
+          << bias_axes << "-->" << out_axes << "," << mean_axes << ","
+          << variance_axes;
+  VLOG(4) << "Out"
+          << " shape: [" << str_join(out_shape) << "] "
+          << " src_dims_mapping: [" << str_join(out_dims_mapping) << "] "
+          << "dst_dims_mapping: ["
+          << str_join(output_dist_attrs[0].dims_mapping()) << "]";
+  VLOG(4) << "Mean"
+          << " shape: [" << str_join(mean_shape) << "] "
+          << " src_dims_mapping: [" << str_join(mean_dims_mapping) << "] "
+          << "dst_dims_mapping: ["
+          << str_join(output_dist_attrs[1].dims_mapping()) << "]";
+  VLOG(4) << "Variance"
+          << " shape: [" << str_join(variance_shape) << "] "
+          << " src_dims_mapping: [" << str_join(variance_dims_mapping) << "] "
+          << "dst_dims_mapping: ["
+          << str_join(output_dist_attrs[2].dims_mapping()) << "]";
+
+  for (int i = 0, n = input_dist_attrs.size(); i < n; i++) {
+    VLOG(4) << "Input" << std::to_string(i) << " dims_mapping: ["
+            << str_join(input_dist_attrs[i].dims_mapping()) << "]";
+  }
+  VLOG(4) << std::endl;
+
+  return {input_dist_attrs, output_dist_attrs};
+}
+
+}  // namespace distributed
+}  // namespace phi
diff --git a/paddle/phi/infermeta/spmd_rules/layer_norm.h b/paddle/phi/infermeta/spmd_rules/layer_norm.h
new file mode 100644
index 00000000000000..c33b58a51bc202
--- /dev/null
+++ b/paddle/phi/infermeta/spmd_rules/layer_norm.h
@@ -0,0 +1,39 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/ + +#pragma once + +#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h" +#include "paddle/phi/core/distributed/type_defs.h" + +namespace phi { +namespace distributed { + +SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x, + const DistMetaTensor& scale, + const DistMetaTensor& bias, + float epsilon, + int begin_norm_axis); + +SpmdInfo LayerNormInferSpmdReverse(const DistMetaTensor& x, + const DistMetaTensor& scale, + const DistMetaTensor& bias, + const DistMetaTensor& out, + const DistMetaTensor& mean, + const DistMetaTensor& variance, + float epsilon, + int begin_norm_axis); + +} // namespace distributed +} // namespace phi diff --git a/paddle/phi/infermeta/spmd_rules/rules.h b/paddle/phi/infermeta/spmd_rules/rules.h index 71a726e3d8edc3..cb01b8996a8c91 100644 --- a/paddle/phi/infermeta/spmd_rules/rules.h +++ b/paddle/phi/infermeta/spmd_rules/rules.h @@ -18,6 +18,7 @@ limitations under the License. */ #include "paddle/phi/infermeta/spmd_rules/default_data_parallel.h" #include "paddle/phi/infermeta/spmd_rules/elementwise.h" +#include "paddle/phi/infermeta/spmd_rules/layer_norm.h" #include "paddle/phi/infermeta/spmd_rules/matmul.h" #include "paddle/phi/infermeta/spmd_rules/reduction.h" #include "paddle/phi/infermeta/spmd_rules/replicated.h" @@ -457,5 +458,11 @@ PD_REGISTER_SPMD_RULE( PD_INFER_SPMD(phi::distributed::ReductionInferSpmd), PD_INFER_SPMD(phi::distributed::ReductionInferSpmdReverse)); +// layer_norm +PD_REGISTER_SPMD_RULE( + layer_norm, + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmd), + PD_INFER_SPMD(phi::distributed::LayerNormInferSpmdReverse)); + } // namespace distributed } // namespace phi diff --git a/test/auto_parallel/spmd_rules/test_layer_norm_rule.py b/test/auto_parallel/spmd_rules/test_layer_norm_rule.py index bac7d12f13b06c..9af336fd8d2143 100644 --- a/test/auto_parallel/spmd_rules/test_layer_norm_rule.py +++ b/test/auto_parallel/spmd_rules/test_layer_norm_rule.py @@ -13,13 +13,14 @@ # limitations under the License. 
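+# NOTE: layer_norm is now routed through the PHI SPMD rule: the rule object is
+# obtained via core.get_phi_spmd_rule("layer_norm"), and infer_forward /
+# infer_backward take the DistTensorSpec arguments positionally, followed by
+# the op attributes (epsilon, begin_norm_axis), instead of a spec list plus
+# an attrs dict.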
import unittest +from collections import OrderedDict -from paddle.distributed.auto_parallel.static.completion import get_spmd_rule from paddle.distributed.auto_parallel.static.dist_attribute import ( DistTensorSpec, TensorDistAttr, ) from paddle.distributed.fleet import auto +from paddle.framework import core class TestLayerNormSPMDRule(unittest.TestCase): @@ -28,7 +29,7 @@ class TestLayerNormSPMDRule(unittest.TestCase): """ def setUp(self): - self.rule = get_spmd_rule("layer_norm") + self.rule = core.get_phi_spmd_rule("layer_norm") x_shape = [64, 32, 1024] scale_shape = [1024] @@ -51,9 +52,7 @@ def setUp(self): self.mean_spec = DistTensorSpec(self.x_spec) self.var_spec = DistTensorSpec(self.x_spec) - self.attrs = { - 'begin_norm_axis': 2, - } + self.attrs = OrderedDict([('epsilon', 1e-3), ('begin_norm_axis', 2)]) def test_infer_forward(self): # ijk[1, -1, -1], k[-1], k[-1] --> @@ -65,7 +64,11 @@ def test_infer_forward(self): self.scale_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_forward( - [self.x_spec, self.scale_spec, self.bias_spec], self.attrs + self.x_spec, + self.scale_spec, + self.bias_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -90,7 +93,11 @@ def test_infer_forward(self): self.bias_spec.set_dims_mapping([0]) result_dist_attrs = self.rule.infer_forward( - [self.x_spec, self.scale_spec, self.bias_spec], self.attrs + self.x_spec, + self.scale_spec, + self.bias_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -119,7 +126,11 @@ def test_infer_forward(self): self.bias_spec.set_dims_mapping([1]) result_dist_attrs = self.rule.infer_forward( - [self.x_spec, self.scale_spec, self.bias_spec], self.attrs + self.x_spec, + self.scale_spec, + self.bias_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -156,9 +167,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -196,9 +212,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([0]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -236,9 +257,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs 
= result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -276,9 +302,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -315,11 +346,16 @@ def test_infer_backward(self): self.mean_spec.set_dims_mapping([0]) self.var_spec.set_dims_mapping([-1]) - with self.assertRaises(BaseException): + with self.assertRaises(NotImplementedError): result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) # [-1, 1, -1], [0], [-1] (outputs) --> @@ -344,9 +380,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -384,9 +425,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] @@ -424,9 +470,14 @@ def test_infer_backward(self): self.var_spec.set_dims_mapping([-1]) result_dist_attrs = self.rule.infer_backward( - [self.x_spec, self.scale_spec, self.bias_spec], - [self.out_spec, self.mean_spec, self.var_spec], - self.attrs, + self.x_spec, + self.scale_spec, + self.bias_spec, + self.out_spec, + self.mean_spec, + self.var_spec, + self.attrs['epsilon'], + self.attrs['begin_norm_axis'], ) infered_input_dist_attrs = result_dist_attrs[0] infered_output_dist_attrs = result_dist_attrs[1] diff --git a/test/cpp/auto_parallel/spmd_rule_test.cc b/test/cpp/auto_parallel/spmd_rule_test.cc index e8f74513fc96f1..42476d7bb323ff 100644 --- a/test/cpp/auto_parallel/spmd_rule_test.cc +++ b/test/cpp/auto_parallel/spmd_rule_test.cc @@ -308,23 +308,28 @@ TEST(LayerNormSPMDRule, Ctor) { bias_dist_attr.set_dims_mapping(std::vector({-1})); bias_dist_attr.set_dynamic_dims(std::vector({false})); - DistTensorSpec x_dist_tensor_spec = DistTensorSpec(x_shape, x_dist_attr); - DistTensorSpec scale_dist_tensor_spec = - DistTensorSpec(scale_shape, scale_dist_attr); - DistTensorSpec bias_dist_tensor_spec = - DistTensorSpec(bias_shape, bias_dist_attr); - paddle::framework::AttributeMap attrs; - attrs["begin_norm_axis"] = 2; + float epsilon = 1e-5; + int begin_norm_axis = 2; - SPMDRuleBase* layer_norm_rule = SPMDRuleMap::Instance().Get("layer_norm"); + auto 
layer_norm_rule =
+      phi::distributed::SpmdRuleFactory::Instance().GetSpmdRule("layer_norm");
   // ijk[1, -1, -1], k[-1], k[-1] --> ijk[1, -1, -1], z[1], z[1], z=ij,
   // begin_norm_axis=2
+  begin_norm_axis = 2;
+  x_dist_attr.set_dims_mapping({1, -1, -1});
+  scale_dist_attr.set_dims_mapping({-1});
+  bias_dist_attr.set_dims_mapping({-1});
+  phi::distributed::DistMetaTensor x(phi::make_ddim(x_shape), x_dist_attr);
+  phi::distributed::DistMetaTensor scale(phi::make_ddim(scale_shape),
+                                         scale_dist_attr);
+  phi::distributed::DistMetaTensor bias(phi::make_ddim(bias_shape),
+                                        bias_dist_attr);
+  phi::distributed::InferSpmdContext ctx({x, scale, bias},
+                                         {epsilon, begin_norm_axis});
   std::pair<std::vector<TensorDistAttr>, std::vector<TensorDistAttr>>
-      infered_dist_attrs = layer_norm_rule->InferForward(
-          {x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
-          attrs);
+      infered_dist_attrs = layer_norm_rule.InferForward(ctx);
   size_t input_size = 3;
   size_t output_size = 3;
@@ -347,12 +352,18 @@ TEST(LayerNormSPMDRule, Ctor) {
   // ijk[1, 0, -1],k[0],k[0] --> ijk[1, -1, -1],z[1],z[1],
   // begin_norm_axis=2
-  x_dist_tensor_spec.set_dims_mapping({1, 0, -1});
-  scale_dist_tensor_spec.set_dims_mapping({0});
-  bias_dist_tensor_spec.set_dims_mapping({0});
-  infered_dist_attrs = layer_norm_rule->InferForward(
-      {x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
-      attrs);
+  begin_norm_axis = 2;
+  x_dist_attr.set_dims_mapping({1, 0, -1});
+  scale_dist_attr.set_dims_mapping({0});
+  bias_dist_attr.set_dims_mapping({0});
+  x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr);
+  scale = phi::distributed::DistMetaTensor(phi::make_ddim(scale_shape),
+                                           scale_dist_attr);
+  bias = phi::distributed::DistMetaTensor(phi::make_ddim(bias_shape),
+                                          bias_dist_attr);
+  ctx = phi::distributed::InferSpmdContext({x, scale, bias},
+                                           {epsilon, begin_norm_axis});
+  infered_dist_attrs = layer_norm_rule.InferForward(ctx);
   EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
             std::vector<int64_t>({1, -1, -1}));
   EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),
@@ -369,13 +380,18 @@ TEST(LayerNormSPMDRule, Ctor) {
   // ijk[0, -1, -1],y[-1],y[1] --> ijk[0, 1, -1], i[0], i[0], y=jk,
   // begin_norm_axis=1
-  x_dist_tensor_spec.set_dims_mapping({0, -1, -1});
-  scale_dist_tensor_spec.set_dims_mapping({-1});
-  bias_dist_tensor_spec.set_dims_mapping({1});
-  attrs["begin_norm_axis"] = 1;
-  infered_dist_attrs = layer_norm_rule->InferForward(
-      {x_dist_tensor_spec, scale_dist_tensor_spec, bias_dist_tensor_spec},
-      attrs);
+  begin_norm_axis = 1;
+  x_dist_attr.set_dims_mapping({0, -1, -1});
+  scale_dist_attr.set_dims_mapping({-1});
+  bias_dist_attr.set_dims_mapping({1});
+  x = phi::distributed::DistMetaTensor(phi::make_ddim(x_shape), x_dist_attr);
+  scale = phi::distributed::DistMetaTensor(phi::make_ddim(scale_shape),
+                                           scale_dist_attr);
+  bias = phi::distributed::DistMetaTensor(phi::make_ddim(bias_shape),
+                                          bias_dist_attr);
+  ctx = phi::distributed::InferSpmdContext({x, scale, bias},
+                                           {epsilon, begin_norm_axis});
+  infered_dist_attrs = layer_norm_rule.InferForward(ctx);
   EXPECT_EQ(infered_dist_attrs.first[0].dims_mapping(),
             std::vector<int64_t>({0, -1, -1}));
   EXPECT_EQ(infered_dist_attrs.first[1].dims_mapping(),