From 855a17886d9cc8314b837cd39e5436c40fbfbf25 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 17:21:52 +0800 Subject: [PATCH 1/2] Finished rms_norm op --- .../multiary_infer_sym.cc | 46 ++++++++++++++++--- .../infer_symbolic_shape/multiary_infer_sym.h | 2 +- paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 42 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 84c3bc82ee3aa..4cb92c101c30b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -1257,12 +1257,46 @@ bool MeshgridOpInferSymbolicShape( // return true; // } -// bool RmsNormOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } +bool RmsNormOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const symbol::ShapeOrDataDimExprs &x_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + std::vector x_dims = x_shape.shape(); + size_t x_dims_size = x_dims.size(); + + symbol::DimExpr normalized_dims(1); + int begin_norm_axis = + op->attribute("begin_norm_axis").data(); + for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { + normalized_dims = normalized_dims * x_dims[i]; + } + + const auto &norm_weight_shape = + infer_context->GetShapeOrDataForValue(op->operand_source(3)); + const std::vector norm_weight_dims = + norm_weight_shape.shape(); + + infer_context->AddEqualCstr(normalized_dims, norm_weight_dims[0]); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)}); + + if (op->result(2)) { + std::vector inv_var_dims(x_dims.begin(), + x_dims.begin() + begin_norm_axis); + infer_context->SetShapeOrDataForValue( + op->result(2), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(inv_var_dims)}); + } + + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)}); + + return true; +} // bool RoiPoolOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h index a06c91cfde013..4beecbdf213ec 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.h @@ -65,7 +65,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid) OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nce) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PsroiPool) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(RmsNorm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiPool) OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 82d6ddf58a5cf..4375be85569ea 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3845,6 +3845,7 @@ optional : bias, residual, norm_bias, residual_out intermediate : inv_var backward : rms_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : rmsprop_ args : (Tensor param, Tensor mean_square, Tensor grad, Tensor moment, Tensor learning_rate, Tensor mean_grad, Tensor master_param, float epsilon = 1.0e-10f, float decay = 0.9f, float momentum = 0.0f, bool centered = false, bool multi_precision = false) From 5f08effced0147be9b880d93fb5e937aad33b2c9 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Mon, 12 Aug 2024 11:21:47 +0800 Subject: [PATCH 2/2] Resolved suggested changes --- .../multiary_infer_sym.cc | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 4cb92c101c30b..57c757b086659 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -1259,32 +1259,32 @@ bool MeshgridOpInferSymbolicShape( bool RmsNormOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - const symbol::ShapeOrDataDimExprs &x_shape = + const symbol::ShapeOrDataDimExprs &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); - std::vector x_dims = x_shape.shape(); - size_t x_dims_size = x_dims.size(); + std::vector x_shape = x_shape_or_data.shape(); + size_t x_shape_size = x_shape.size(); symbol::DimExpr normalized_dims(1); int begin_norm_axis = op->attribute("begin_norm_axis").data(); - for (size_t i = begin_norm_axis; i < x_dims_size; ++i) { - normalized_dims = normalized_dims * x_dims[i]; + for (size_t i = begin_norm_axis; i < x_shape_size; ++i) { + normalized_dims = normalized_dims * x_shape[i]; } const auto &norm_weight_shape = infer_context->GetShapeOrDataForValue(op->operand_source(3)); - const std::vector norm_weight_dims = + const std::vector &norm_weight_dims = norm_weight_shape.shape(); infer_context->AddEqualCstr(normalized_dims, norm_weight_dims[0]); infer_context->SetShapeOrDataForValue( op->result(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)}); + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); if (op->result(2)) { - std::vector inv_var_dims(x_dims.begin(), - x_dims.begin() + begin_norm_axis); + std::vector inv_var_dims( + x_shape.begin(), x_shape.begin() + begin_norm_axis); infer_context->SetShapeOrDataForValue( op->result(2), symbol::ShapeOrDataDimExprs{ @@ -1293,7 +1293,7 @@ bool RmsNormOpInferSymbolicShape( infer_context->SetShapeOrDataForValue( op->result(1), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_dims)}); + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); return true; }