From 5370a45b946ea5195b83df03d7f19de310e90c81 Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Fri, 15 Aug 2025 05:39:10 +0000 Subject: [PATCH 1/2] skip check when runtime --- .../pir/dialect/op_generator/op_build_gen.py | 1 + paddle/phi/infermeta/multiary.cc | 32 ++++++++++++------- paddle/phi/infermeta/multiary.h | 3 +- 3 files changed, 23 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index c426d3325a0811..657ff43683560e 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -94,6 +94,7 @@ 'LegacyInterpolateInferMeta', 'NceInferMeta', 'PyramidHashInferMeta', + 'RmsNormInferMeta', 'SigmoidCrossEntropyWithLogitsInferMeta', 'StackInferMeta', 'WeightOnlyLinearInferMeta', diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index bb10157cfc69da..2d254651c52380 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -4944,26 +4944,34 @@ void RmsNormInferMeta(const MetaTensor& x, const float quant_min_bound, MetaTensor* out, MetaTensor* residual_out, - MetaTensor* inv_var) { + MetaTensor* inv_var, + MetaConfig config) { size_t x_dims_size = x.dims().size(); size_t normalized_dims = 1; + bool has_minus_one = false; for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { normalized_dims *= x.dims().at(i); + has_minus_one |= (x.dims().at(i) == -1); } - if (normalized_dims != 0) { - PADDLE_ENFORCE_EQ(normalized_dims, - norm_weight.dims()[0], - common::errors::InvalidArgument( - "The normalized size of Input(X) must equal to be " - "the size of Weight, but received " - "normalized size of Input(X) is [%d], received size " - "of Weight is [%d]", - normalized_dims, - norm_weight.dims()[0])); - } + // NOTE: Although 'goto' is generally discouraged, its use here replaces two + // obscure if-statements, improves readability, and does not cross large code + // blocks or affect resource management. The jump is clear and safe in this + // context. + if (normalized_dims == 0) goto skip_check; + if (has_minus_one && !config.is_runtime) goto skip_check; + PADDLE_ENFORCE_EQ(normalized_dims, + norm_weight.dims()[0], + common::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be " + "the size of Weight, but received " + "normalized size of Input(X) is [%d], received size " + "of Weight is [%d]", + normalized_dims, + norm_weight.dims()[0])); +skip_check: out->set_dims(x.dims()); if (quant_scale > 0) { diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index 67027f75097f7e..224a1376902672 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -988,7 +988,8 @@ void RmsNormInferMeta(const MetaTensor& x, const float quant_min_bound, MetaTensor* out, MetaTensor* residual_out, - MetaTensor* inv_var); + MetaTensor* inv_var, + MetaConfig config = MetaConfig()); void RmspropInferMeta(const MetaTensor& param, const MetaTensor& mean_square, From 99b59a6ec7aee5fe1c8612f66374d43bd3e6397a Mon Sep 17 00:00:00 2001 From: DrRyanHuang Date: Fri, 15 Aug 2025 06:59:17 +0000 Subject: [PATCH 2/2] use if instead of goto --- paddle/phi/infermeta/multiary.cc | 32 +++++++++++++++----------------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 2d254651c52380..51af7a9c2fe168 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -4955,23 +4955,21 @@ void RmsNormInferMeta(const MetaTensor& x, has_minus_one |= (x.dims().at(i) == -1); } - // NOTE: Although 'goto' is generally discouraged, its use here replaces two - // obscure if-statements, improves readability, and does not cross large code - // blocks or affect resource management. The jump is clear and safe in this - // context. - if (normalized_dims == 0) goto skip_check; - if (has_minus_one && !config.is_runtime) goto skip_check; - - PADDLE_ENFORCE_EQ(normalized_dims, - norm_weight.dims()[0], - common::errors::InvalidArgument( - "The normalized size of Input(X) must equal to be " - "the size of Weight, but received " - "normalized size of Input(X) is [%d], received size " - "of Weight is [%d]", - normalized_dims, - norm_weight.dims()[0])); -skip_check: + bool skip_check = false; + if (normalized_dims == 0) skip_check = true; + if (has_minus_one && !config.is_runtime) skip_check = true; + + if (!skip_check) { + PADDLE_ENFORCE_EQ(normalized_dims, + norm_weight.dims()[0], + common::errors::InvalidArgument( + "The normalized size of Input(X) must equal to be " + "the size of Weight, but received " + "normalized size of Input(X) is [%d], received size " + "of Weight is [%d]", + normalized_dims, + norm_weight.dims()[0])); + } out->set_dims(x.dims()); if (quant_scale > 0) {