diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index a35095c98d4a2..66f17168ec01a 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -22,6 +22,9 @@ namespace paddle {
 namespace primitive {
 namespace details {
 
+// empty_shape means x.shape=[]
+static std::vector<int64_t> empty_shape;
+
 template <typename T>
 Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
   auto org_dtype = x.dtype();
@@ -345,62 +348,66 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
 
   // cast dtype to float32 if dtype =float16 or bfloat16
   if (need_cast) {
-    x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
+    x_cast = cast<T>(x_cast, DataType::FLOAT32);
   }
 
   auto x_dim = common::vectorize(x.dims());
   for (size_t i = begin_norm_axis; i < x_dim.size(); i++) {
     axis.push_back(static_cast<int64_t>(i));
   }
-  auto mean_ = mean_decomp<T>(x_cast, IntArray(axis), true);
+  auto mean_ = mean_decomp<T>(x_cast, axis, true);
   auto difference = x_cast - mean_;
   auto var_tmp1 = difference * difference;
-  auto variance = mean_decomp<T>(var_tmp1, IntArray(axis), true);
+  auto variance = mean_decomp<T>(var_tmp1, axis, true);
   auto var_tmp3 = variance + epsilon;
   auto rsqrt_var = elementwise_pow<T>(
-      var_tmp3,
-      full<T>(common::vectorize(var_tmp3.dims()), -0.5, var_tmp3.dtype()));
+      var_tmp3, full<T>(empty_shape, -0.5, var_tmp3.dtype()));
   auto out = difference * rsqrt_var;
 
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();
 
-  std::vector<int64_t> slice_shape;
-  for (int64_t i = begin_norm_axis; i < static_cast<int64_t>(x_dim.size());
-       i++) {
-    slice_shape.push_back(x_dim[i]);
+  std::vector<int64_t> slice_shape_l;
+  std::vector<int64_t> slice_shape_r;
+  for (int64_t i = 0; i < static_cast<int64_t>(x_dim.size()); i++) {
+    if (i < begin_norm_axis) {
+      slice_shape_l.push_back(x_dim[i]);
+    } else {
+      slice_shape_r.push_back(x_dim[i]);
+    }
   }
   Tensor scale_cast;
   if (scale_ptr) {
-    if (slice_shape != scale_ptr->shape()) {
-      scale_cast = reshape<T>(*scale_ptr, slice_shape);
+    if (slice_shape_r != scale_ptr->shape()) {
+      scale_cast = reshape<T>(*scale_ptr, slice_shape_r);
     } else {
       scale_cast = *scale_ptr;
     }
     if (need_cast) {
-      scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
+      scale_cast = cast<T>(scale_cast, DataType::FLOAT32);
     }
     out = out * scale_cast;
   }
   Tensor bias_cast;
   if (bias_ptr) {
-    if (slice_shape != bias_ptr->shape()) {
-      bias_cast = reshape<T>(*bias_ptr, slice_shape);
+    if (slice_shape_r != bias_ptr->shape()) {
+      bias_cast = reshape<T>(*bias_ptr, slice_shape_r);
     } else {
       bias_cast = *bias_ptr;
     }
     if (need_cast) {
-      bias_cast = cast<T>(bias_cast, phi::DataType::FLOAT32);
+      bias_cast = cast<T>(bias_cast, DataType::FLOAT32);
    }
     out = out + bias_cast;
   }
-  mean_ = reshape<T>(mean_, std::vector<int64_t>({-1}));
-  variance = reshape<T>(variance, std::vector<int64_t>({-1}));
 
+  mean_ = reshape<T>(mean_, slice_shape_l);
+  variance = reshape<T>(variance, slice_shape_l);
+  // same as LayerNormInferMeta
+  // x: float32 --> out: float32, mean: float32, variance: float32
+  // x: float16 --> out: float16, mean: float32, variance: float32
   if (need_cast) {
     out = cast<T>(out, org_dtype);
-    mean_ = cast<T>(mean_, org_dtype);
-    variance = cast<T>(variance, org_dtype);
   }
 
   return std::make_tuple(out, mean_, variance);
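
Not part of the diff itself, but to make the behavioral changes concrete, here is a minimal plain-C++ sketch of the layer-norm semantics the decomposition above expresses. `layer_norm_ref` and its entire signature are hypothetical (nothing here is Paddle API); it works on flat float32 buffers, so the half-precision `need_cast` path is elided. It mirrors the three changes in the hunk: mean/variance keep the leading dims (`slice_shape_l`) instead of being flattened to `[-1]`; scale/bias are matched against the trailing dims (`slice_shape_r`); and the `-0.5` exponent is a plain scalar, which the new code builds as a 0-D tensor via `full(empty_shape, ...)` rather than a tensor of `var_tmp3`'s full shape.

```cpp
// Hypothetical reference implementation; a sketch of the semantics the
// decomposition above encodes, assuming row-major data and float32 throughout.
#include <cmath>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

struct LayerNormOut {
  std::vector<float> out;       // same shape as x
  std::vector<float> mean;      // leading dims only, i.e. slice_shape_l
  std::vector<float> variance;  // leading dims only, i.e. slice_shape_l
};

LayerNormOut layer_norm_ref(const std::vector<float>& x,
                            const std::vector<int64_t>& shape,
                            int begin_norm_axis,
                            float epsilon,
                            const std::vector<float>* scale,   // size: product of trailing dims
                            const std::vector<float>* bias) {  // size: product of trailing dims
  // Split the shape at begin_norm_axis, as the slice_shape_l/slice_shape_r
  // loop in the diff does.
  int64_t outer = 1, inner = 1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (static_cast<int>(i) < begin_norm_axis) {
      outer *= shape[i];  // slice_shape_l: shape of mean/variance
    } else {
      inner *= shape[i];  // slice_shape_r: shape scale/bias are reshaped to
    }
  }
  LayerNormOut r{std::vector<float>(x.size()),
                 std::vector<float>(outer),
                 std::vector<float>(outer)};
  for (int64_t o = 0; o < outer; ++o) {
    const float* row = x.data() + o * inner;
    // mean_ = mean(x) over the normalized (trailing) axes, keepdim=true
    float mean = std::accumulate(row, row + inner, 0.0f) / inner;
    // variance = mean((x - mean_)^2): the biased estimator
    float var = 0.0f;
    for (int64_t i = 0; i < inner; ++i) {
      float d = row[i] - mean;
      var += d * d;
    }
    var /= inner;
    // rsqrt_var = (variance + epsilon)^(-0.5); the -0.5 is the 0-D scalar
    // the diff now builds with full(empty_shape, -0.5, ...)
    float rsqrt_var = std::pow(var + epsilon, -0.5f);
    for (int64_t i = 0; i < inner; ++i) {
      float v = (row[i] - mean) * rsqrt_var;
      if (scale) v *= (*scale)[i];
      if (bias) v += (*bias)[i];
      r.out[o * inner + i] = v;
    }
    // mean/variance keep the leading dims instead of being flattened to [-1]
    r.mean[o] = mean;
    r.variance[o] = var;
  }
  return r;
}

int main() {
  // x: shape [2, 3], begin_norm_axis = 1 -> normalize each row of 3 elements;
  // mean/variance come out with shape [2] (slice_shape_l).
  std::vector<float> x = {1, 2, 3, 4, 5, 6};
  auto r = layer_norm_ref(x, {2, 3}, /*begin_norm_axis=*/1, 1e-5f,
                          /*scale=*/nullptr, /*bias=*/nullptr);
  std::cout << "mean[0]=" << r.mean[0] << " variance[0]=" << r.variance[0]
            << " out[0]=" << r.out[0] << "\n";  // mean 2, variance ~0.667
}
```

Note also the dtype contract the new comment pins down: the output is cast back to the input dtype while mean and variance are left in float32, matching `LayerNormInferMeta`; that is why the two trailing `cast` calls on `mean_` and `variance` are removed.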