From 633479cc830d1691d0f5def908f81d193e063009 Mon Sep 17 00:00:00 2001
From: "haowen.han@mthreads.com" <haowen.han@mthreads.com>
Date: Mon, 13 May 2024 13:29:35 +0800
Subject: [PATCH] Revert "fix layer_norm decompose dtyte bugs, polish codes
 (#61631)"

This reverts commit e5a85b63b70a5c3ba6fabd775566c40e95f19388.
---
 paddle/fluid/primitive/composite/composite.h | 45 +++++++++-----------
 1 file changed, 19 insertions(+), 26 deletions(-)

diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h
index 66f17168ec01a..a35095c98d4a2 100644
--- a/paddle/fluid/primitive/composite/composite.h
+++ b/paddle/fluid/primitive/composite/composite.h
@@ -22,9 +22,6 @@ namespace paddle {
 namespace primitive {
 namespace details {
 
-// empty_shape means x.shape=[]
-static std::vector<int64_t> empty_shape;
-
 template <typename T>
 Tensor mean_decomp(const Tensor& x, const IntArray& axis, bool keepdim) {
   auto org_dtype = x.dtype();
@@ -348,66 +345,62 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
 
   // cast dtype to float32 if dtype =float16 or bfloat16
   if (need_cast) {
-    x_cast = cast<T>(x_cast, DataType::FLOAT32);
+    x_cast = cast<T>(x_cast, phi::DataType::FLOAT32);
   }
 
   auto x_dim = common::vectorize(x.dims());
   for (size_t i = begin_norm_axis; i < x_dim.size(); i++) {
     axis.push_back(static_cast<int64_t>(i));
   }
-  auto mean_ = mean_decomp<T>(x_cast, axis, true);
+  auto mean_ = mean_decomp<T>(x_cast, IntArray(axis), true);
   auto difference = x_cast - mean_;
   auto var_tmp1 = difference * difference;
-  auto variance = mean_decomp<T>(var_tmp1, axis, true);
+  auto variance = mean_decomp<T>(var_tmp1, IntArray(axis), true);
   auto var_tmp3 = variance + epsilon;
   auto rsqrt_var = elementwise_pow<T>(
-      var_tmp3, full<T>(empty_shape, -0.5, var_tmp3.dtype()));
+      var_tmp3,
+      full<T>(common::vectorize(var_tmp3.dims()), -0.5, var_tmp3.dtype()));
   auto out = difference * rsqrt_var;
 
   auto scale_ptr = scale.get_ptr();
   auto bias_ptr = bias.get_ptr();
 
-  std::vector<int64_t> slice_shape_l;
-  std::vector<int64_t> slice_shape_r;
-  for (int64_t i = 0; i < static_cast<int64_t>(x_dim.size()); i++) {
-    if (i < begin_norm_axis) {
-      slice_shape_l.push_back(x_dim[i]);
-    } else {
-      slice_shape_r.push_back(x_dim[i]);
-    }
+  std::vector<int64_t> slice_shape;
+  for (int64_t i = begin_norm_axis; i < static_cast<int64_t>(x_dim.size());
+       i++) {
+    slice_shape.push_back(x_dim[i]);
   }
   Tensor scale_cast;
   if (scale_ptr) {
-    if (slice_shape_r != scale_ptr->shape()) {
-      scale_cast = reshape<T>(*scale_ptr, slice_shape_r);
+    if (slice_shape != scale_ptr->shape()) {
+      scale_cast = reshape<T>(*scale_ptr, slice_shape);
     } else {
       scale_cast = *scale_ptr;
     }
     if (need_cast) {
-      scale_cast = cast<T>(scale_cast, DataType::FLOAT32);
+      scale_cast = cast<T>(scale_cast, phi::DataType::FLOAT32);
     }
     out = out * scale_cast;
   }
   Tensor bias_cast;
   if (bias_ptr) {
-    if (slice_shape_r != bias_ptr->shape()) {
-      bias_cast = reshape<T>(*bias_ptr, slice_shape_r);
+    if (slice_shape != bias_ptr->shape()) {
+      bias_cast = reshape<T>(*bias_ptr, slice_shape);
     } else {
       bias_cast = *bias_ptr;
     }
     if (need_cast) {
-      bias_cast = cast<T>(bias_cast, DataType::FLOAT32);
+      bias_cast = cast<T>(bias_cast, phi::DataType::FLOAT32);
     }
     out = out + bias_cast;
   }
-  mean_ = reshape<T>(mean_, slice_shape_l);
-  variance = reshape<T>(variance, slice_shape_l);
+  mean_ = reshape<T>(mean_, std::vector<int64_t>({-1}));
+  variance = reshape<T>(variance, std::vector<int64_t>({-1}));
 
-  // same as LayerNormInferMeta
-  // x: float32 --> out: float32, mean: float32, variance: float32
-  // x: float16 --> out: float16, mean: float32, variance: float32
   if (need_cast) {
     out = cast<T>(out, org_dtype);
+    mean_ = cast<T>(mean_, org_dtype);
+    variance = cast<T>(variance, org_dtype);
   }
 
   return std::make_tuple(out, mean_, variance);