From 60fe1c6c6d8038ad09ef52e89ca76ea70866817e Mon Sep 17 00:00:00 2001
From: Colin Mufan <76479709+MufanColin@users.noreply.github.com>
Date: Tue, 13 Aug 2024 10:12:27 +0800
Subject: [PATCH] [Infer Symbolic Shape No.166][BUAA] `rms_norm` (#67294)

* Finished rms_norm op

* Resolved suggested changes
---
 Review notes (free-form text in this position is ignored by `git am`):

 - This copy of the patch had been flattened and had lost its template
   argument lists. The line structure was rebuilt and the arguments
   `std::vector<symbol::DimExpr>` (three places) and
   `op->attribute<pir::Int32Attribute>(...)` were restored. Both types are
   inferred from how the values are used (DimExpr multiplication, `int`
   result of `.data()`); verify against the upstream commit.
 - RmsNormOpInferSymbolicShape reads the symbolic shape of operand 0 (x),
   multiplies its dims from `begin_norm_axis` to the end, and adds an
   equality constraint between that product and dim 0 of operand 3
   (`norm_weight_*` in the code). Results 0 and 1 are given x's shape;
   result 2 (`inv_var_dims`), when non-null, is given x's dims that precede
   `begin_norm_axis`. The function always returns true.
 - NOTE(review): the code assumes 0 <= begin_norm_axis <= rank(x) and that
   operand 3 has at least one dim; neither is checked in this function.
   TODO confirm an earlier validation step guarantees both.

 .../multiary_infer_sym.cc                     | 46 ++++++++++++++++---
 .../infer_symbolic_shape/multiary_infer_sym.h |  2 +-
 paddle/phi/ops/yaml/ops.yaml                  |  1 +
 3 files changed, 42 insertions(+), 7 deletions(-)

diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
index a038dcf20e1d0..d4a9eabc21cad 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc
@@ -1418,12 +1418,46 @@ bool MovingAverageAbsMaxScale_OpInferSymbolicShape(
 //   return true;
 // }
 
-// bool RmsNormOpInferSymbolicShape(pir::Operation *op,
-//                                  pir::InferSymbolicShapeContext
-//                                  *infer_context) {
-//   // pass
-//   return true;
-// }
+bool RmsNormOpInferSymbolicShape(
+    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
+  const symbol::ShapeOrDataDimExprs &x_shape_or_data =
+      infer_context->GetShapeOrDataForValue(op->operand_source(0));
+  std::vector<symbol::DimExpr> x_shape = x_shape_or_data.shape();
+  size_t x_shape_size = x_shape.size();
+
+  symbol::DimExpr normalized_dims(1);
+  int begin_norm_axis =
+      op->attribute<pir::Int32Attribute>("begin_norm_axis").data();
+  for (size_t i = begin_norm_axis; i < x_shape_size; ++i) {
+    normalized_dims = normalized_dims * x_shape[i];
+  }
+
+  const auto &norm_weight_shape =
+      infer_context->GetShapeOrDataForValue(op->operand_source(3));
+  const std::vector<symbol::DimExpr> &norm_weight_dims =
+      norm_weight_shape.shape();
+
+  infer_context->AddEqualCstr(normalized_dims, norm_weight_dims[0]);
+
+  infer_context->SetShapeOrDataForValue(
+      op->result(0),
+      symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});
+
+  if (op->result(2)) {
+    std::vector<symbol::DimExpr> inv_var_dims(
+        x_shape.begin(), x_shape.begin() + begin_norm_axis);
+    infer_context->SetShapeOrDataForValue(
+        op->result(2),
+        symbol::ShapeOrDataDimExprs{
+            symbol::TensorShapeOrDataDimExprs(inv_var_dims)});
+  }
+
+  infer_context->SetShapeOrDataForValue(
+      op->result(1),
+      symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)});
+
+  return true;
+}
 
 // bool RoiPoolOpInferSymbolicShape(pir::Operation *op,
 //                                  pir::InferSymbolicShapeContext
diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
index b16b28e9218ae..095590eca991d 100644
--- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
+++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h
@@ -74,7 +74,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MovingAverageAbsMaxScale_)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool)
-// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm)
+OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm)
 // OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiPool)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign)
 OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack)
diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml
index 5a0328f1f949e..62d5e22bf990b 100755
--- a/paddle/phi/ops/yaml/ops.yaml
+++ b/paddle/phi/ops/yaml/ops.yaml
@@ -3881,6 +3881,7 @@
   optional : bias, residual, norm_bias, residual_out
   intermediate : inv_var
   backward : rms_norm_grad
+  interfaces : paddle::dialect::InferSymbolicShapeInterface
 
 - op : rmsprop_
   args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false)