[AutoParallel] fix spmd layer norm (#62439)
* [AutoParallel] fix spmd layer norm

* layer norm

* fix
liym27 authored Mar 12, 2024
1 parent 62b26cb commit b3c9734
Showing 1 changed file with 52 additions and 11 deletions.
paddle/phi/infermeta/spmd_rules/layer_norm.cc
@@ -26,6 +26,26 @@ namespace distributed {
 
 using phi::distributed::auto_parallel::str_join;
 
+void LogInputDistAttr(const std::string& name,
+                      const std::vector<int64_t>& shape,
+                      const TensorDistAttr& src_dist_attr,
+                      const TensorDistAttr& dst_dist_attr) {
+  VLOG(4) << name << " shape: [" << str_join(shape) << "] "
+          << "src_dims_mapping: [" << str_join(src_dist_attr.dims_mapping())
+          << "] "
+          << "dst_dims_mapping: [" << str_join(dst_dist_attr.dims_mapping())
+          << "] "
+          << "src_partial: " << src_dist_attr.partial_status_string()
+          << " dst_partial: " << dst_dist_attr.partial_status_string();
+}
+
+void LogOutputDistAttr(const std::string& name,
+                       const TensorDistAttr& dst_dist_attr) {
+  VLOG(4) << name << " dims mapping: ["
+          << str_join(dst_dist_attr.dims_mapping()) << "] "
+          << "partial: " << dst_dist_attr.partial_status_string();
+}
+
 SpmdInfo LayerNormInferSpmd(const DistMetaTensor& x,
                             const DistMetaTensor& scale,
                             const DistMetaTensor& bias,
@@ -347,12 +367,16 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
   TensorDistAttr x_dist_attr;
   TensorDistAttr mean_dist_attr;
   TensorDistAttr variance_dist_attr;
-  TensorDistAttr grad_dist_attr;
+  TensorDistAttr out_grad_dist_attr;
+
   std::vector<TensorDistAttr> dist_attrs;
   dist_attrs.push_back(x.dist_attr());
   dist_attrs.push_back(mean.dist_attr());
   dist_attrs.push_back(variance.dist_attr());
-  dist_attrs.push_back(out_grad.dist_attr());
+  out_grad_dist_attr = out_grad.dist_attr();
+  out_grad_dist_attr.clean_partial_status();
+  dist_attrs.push_back(out_grad_dist_attr);
+
   if (begin_norm_axis > 0) {
     std::vector<std::vector<int64_t>> shapes = {
         x_shape, mean_shape, variance_shape, x_shape};
@@ -365,16 +389,17 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
     x_dist_attr = std::move(dist_attrs[0]);
     mean_dist_attr = std::move(dist_attrs[1]);
     variance_dist_attr = std::move(dist_attrs[2]);
-    grad_dist_attr = std::move(dist_attrs[3]);
+    out_grad_dist_attr = std::move(dist_attrs[3]);
   } else {
     x_dist_attr = GetReplicatedDistAttr(dist_attrs[0]);
     mean_dist_attr = GetReplicatedDistAttr(dist_attrs[1]);
     variance_dist_attr = GetReplicatedDistAttr(dist_attrs[2]);
-    grad_dist_attr = GetReplicatedDistAttr(dist_attrs[3]);
+    out_grad_dist_attr = GetReplicatedDistAttr(dist_attrs[3]);
   }
   // TODO(liuzhenhai): support sharded scale and bias
   TensorDistAttr scale_dist_attr = GetReplicatedDistAttr(scale.dist_attr());
   TensorDistAttr bias_dist_attr = GetReplicatedDistAttr(bias.dist_attr());
+  TensorDistAttr x_grad_dist_attr = out_grad_dist_attr;
   TensorDistAttr scale_grad_dist_attr =
       GetReplicatedDistAttr(scale.dist_attr());
   TensorDistAttr bias_grad_dist_attr = GetReplicatedDistAttr(bias.dist_attr());
@@ -390,13 +415,29 @@ SpmdInfo LayerNormGradInferSpmd(const DistMetaTensor& x,
   scale_grad_dist_attr.set_partial_status(partial_on_dims);
   bias_grad_dist_attr.set_partial_status(partial_on_dims);
 
-  return SpmdInfo({x_dist_attr,
-                   scale_dist_attr,
-                   bias_dist_attr,
-                   mean_dist_attr,
-                   variance_dist_attr,
-                   grad_dist_attr},
-                  {grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr});
+  VLOG(4) << "LayerNormGradInferSpmd:";
+  VLOG(4) << "begin_norm_axis: " << begin_norm_axis;
+  LogInputDistAttr("X", x_shape, x.dist_attr(), x_dist_attr);
+  LogInputDistAttr("Scale", scale_shape, scale.dist_attr(), scale_dist_attr);
+  LogInputDistAttr("Bias", bias_shape, bias.dist_attr(), bias_dist_attr);
+  LogInputDistAttr("Mean", mean_shape, mean.dist_attr(), mean_dist_attr);
+  LogInputDistAttr(
+      "Variance", variance_shape, variance.dist_attr(), variance_dist_attr);
+  LogInputDistAttr(
+      "OutGrad", out_grad_shape, out_grad.dist_attr(), out_grad_dist_attr);
+  LogOutputDistAttr("XGrad", x_grad_dist_attr);
+  LogOutputDistAttr("ScaleGrad", scale_grad_dist_attr);
+  LogOutputDistAttr("BiasGrad", bias_grad_dist_attr);
+  VLOG(4) << std::endl;
+
+  return SpmdInfo(
+      {x_dist_attr,
+       scale_dist_attr,
+       bias_dist_attr,
+       mean_dist_attr,
+       variance_dist_attr,
+       out_grad_dist_attr},
+      {x_grad_dist_attr, scale_grad_dist_attr, bias_grad_dist_attr});
 }
 
 } // namespace distributed
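
Note on the core change: the grad rule previously pushed out_grad.dist_attr() into the alignment step as-is and reused the aligned attribute (grad_dist_attr) for both the out_grad input and the x_grad output. The fix clears out_grad's partial status before alignment and introduces a dedicated x_grad_dist_attr. The sketch below illustrates the bookkeeping with a toy stand-in type; ToyDistAttr and its fields are hypothetical simplifications, not Paddle's actual TensorDistAttr API:

#include <cstdint>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <vector>

struct ToyDistAttr {
  std::vector<int64_t> dims_mapping;  // tensor axis -> mesh axis (-1 = replicated)
  std::set<int64_t> partial_dims;     // mesh axes holding unreduced partial sums

  void clean_partial_status() { partial_dims.clear(); }

  std::string partial_status_string() const {
    std::ostringstream os;
    os << "{";
    for (int64_t d : partial_dims) os << d << " ";
    os << "}";
    return os.str();
  }
};

int main() {
  // out_grad arrives sharded on mesh axis 0 and partial on mesh axis 1.
  ToyDistAttr out_grad_dist_attr{{0, -1}, {1}};

  // The fix: drop the pending-reduction state before the attribute is
  // aligned with x/mean/variance and reused for x_grad.
  out_grad_dist_attr.clean_partial_status();
  ToyDistAttr x_grad_dist_attr = out_grad_dist_attr;

  std::cout << "x_grad partial: " << x_grad_dist_attr.partial_status_string()
            << "\n";  // prints "{}" -- no stale partial state propagated
  return 0;
}

A tensor marked partial on a mesh axis holds unreduced partial sums there; clearing the flag before the attribute flows into x_grad keeps that pending-reduction state from silently propagating to the output.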
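
The surrounding context lines (unchanged by this commit) mark scale_grad and bias_grad as partial on partial_on_dims: these gradients are sums over the pre-normalization axes, so every mesh axis that shards x on one of those axes leaves each rank with only a partial sum. The helper below is a hypothetical reconstruction of how such a set could be derived; the actual computation of partial_on_dims lies outside the displayed hunks:

#include <cstdint>
#include <set>
#include <vector>

// Hypothetical helper: collect the mesh axes that shard x on any axis
// before begin_norm_axis, i.e. the axes scale_grad/bias_grad sum over.
std::set<int64_t> PartialOnDims(const std::vector<int64_t>& x_dims_mapping,
                                int begin_norm_axis) {
  std::set<int64_t> partial;
  for (int axis = 0; axis < begin_norm_axis; ++axis) {
    int64_t mesh_dim = x_dims_mapping[axis];
    if (mesh_dim >= 0) partial.insert(mesh_dim);  // sharded -> partial sum
  }
  return partial;
}

int main() {
  // x: [batch, seq, hidden], batch sharded on mesh axis 0, begin_norm_axis = 2.
  std::vector<int64_t> x_dims_mapping{0, -1, -1};
  std::set<int64_t> partial = PartialOnDims(x_dims_mapping, /*begin_norm_axis=*/2);
  // partial == {0}: scale_grad/bias_grad still owe an allreduce over mesh
  // axis 0, which is what set_partial_status(partial_on_dims) records.
  return partial.count(0) == 1 ? 0 : 1;
}

Recording the partial status, rather than eagerly reducing, leaves the reshard machinery free to insert the allreduce wherever the gradients are actually consumed.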
