From 2a2ca540b63f44eb286a9732338a4348e7dbeb1b Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Wed, 7 Aug 2024 14:37:17 +0800 Subject: [PATCH 1/8] Finished norm and p_norm --- .../infer_symbolic_shape/unary_infer_sym.cc | 109 ++++++++++++++++-- .../infer_symbolic_shape/unary_infer_sym.h | 4 +- paddle/phi/ops/yaml/ops.yaml | 2 + 3 files changed, 101 insertions(+), 14 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 5f73602b90e9a..66952b2985671 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1158,12 +1158,33 @@ bool MeanAllOpInferSymbolicShape( // return true; // } -// bool NormOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext *infer_context) -// { -// // pass -// return true; -// } +bool NormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + auto x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); + + int axis = op->attribute<pir::Int32Attribute>("axis").data(); + float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data(); + bool is_test = op->attribute<pir::BoolAttribute>("is_test").data(); + + if (!is_test) { + if (axis < 0) axis += x_shape.size(); + + auto norm_shape = x_shape; + norm_shape[axis] = symbol::DimExpr(1); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(norm_shape)}); + } + + return true; +} bool NonzeroOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { @@ -1198,12 +1219,76 @@ bool NumelOpInferSymbolicShape(pir::Operation *op, return true; } -// bool P_NormOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } + +bool P_NormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + auto x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); + auto x_rank = x_shape.size(); + + float porder = op->attribute<pir::FloatAttribute>("porder").data(); + int axis = op->attribute<pir::Int32Attribute>("axis").data(); + float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data(); + bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data(); + bool asvector = op->attribute<pir::BoolAttribute>("asvector").data(); + + PADDLE_ENFORCE_GE(axis, + -x_rank, + common::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_shape)); + PADDLE_ENFORCE_LT(axis, + x_rank, + common::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is " + "the rank of Input(X). But received axis: %d, R: %d. " + "Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_shape)); + + std::vector<symbol::DimExpr> out_dim_vector; + + if (asvector) { + if (keepdim) { + for (size_t i = 0; i < x_rank; ++i) { + out_dim_vector.emplace_back(symbol::DimExpr(1)); + } + } else { + out_dim_vector = {}; + } + } else { + if (axis < 0) axis += x_rank; + + if (keepdim) { + for (size_t i = 0; i < x_rank; ++i) { + if (static_cast<int>(i) == axis) { + out_dim_vector.emplace_back(symbol::DimExpr(1)); + } else { + out_dim_vector.emplace_back(x_shape[i]); + } + } + } else { + for (size_t i = 0; i < x_rank; ++i) { + if (static_cast<int>(i) != axis) { + out_dim_vector.emplace_back(x_shape[i]); + } + } + } + } + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_dim_vector)}); + + return true; +} // bool PartialSumOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 9ffec0c4edb10..fe09591a9f8dc 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -79,10 +79,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 319d8bea6d9cd..03555972ad86f 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3367,6 +3367,7 @@ kernel : func : norm backward : norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : npu_identity args : (Tensor x, int format = -1) @@ -3428,6 +3429,7 @@ kernel : func : p_norm backward : p_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : pad args : (Tensor x, int[] paddings, Scalar pad_value) From 407ee474ed9e71ff241e8166244f1fa8ebb37250 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Wed, 7 Aug 2024 15:27:44 +0800 Subject: [PATCH 2/8] Fixed name errors --- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 4 ++-- .../operator/interface/infer_symbolic_shape/unary_infer_sym.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 66952b2985671..116938b6d8af6 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1220,8 +1220,8 @@ bool NumelOpInferSymbolicShape(pir::Operation *op, return true; } -bool P_NormOpInferSymbolicShape(pir::Operation *op, - pir::InferSymbolicShapeContext *infer_context) { +bool PNormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { auto x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index fe09591a9f8dc..a34f4768fbe13 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -82,7 +82,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) -OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(PNorm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) From 85ce4b32268e2302501f1351f86ea84b858f2bd4 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Wed, 7 Aug 2024 17:15:56 +0800 Subject: [PATCH 3/8] Removed unused variables --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 3 --- 1 file changed, 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 116938b6d8af6..6a9e3a91ed066 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1169,7 +1169,6 @@ bool NormOpInferSymbolicShape(pir::Operation *op, symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); int axis = op->attribute<pir::Int32Attribute>("axis").data(); - float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data(); bool is_test = op->attribute<pir::BoolAttribute>("is_test").data(); if (!is_test) { @@ -1227,9 +1226,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, const auto &x_shape = x_shape_or_data.shape(); auto x_rank = x_shape.size(); - float porder = op->attribute<pir::FloatAttribute>("porder").data(); int axis = op->attribute<pir::Int32Attribute>("axis").data(); - float epsilon = op->attribute<pir::FloatAttribute>("epsilon").data(); bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data(); bool asvector = op->attribute<pir::BoolAttribute>("asvector").data(); From 5ffe7905ebf20e5f0d66f271e8b3a1cbd3e5236f Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:36:29 +0800 Subject: [PATCH 4/8] Updated norm and pnorm in unary_infer_sym.cc according to suggested changes --- .../infer_symbolic_shape/unary_infer_sym.cc | 79 ++++++++++--------- 1 file changed, 42 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 6a9e3a91ed066..7d4b8080c36ac 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1174,12 +1174,12 @@ bool NormOpInferSymbolicShape(pir::Operation *op, if (!is_test) { if (axis < 0) axis += x_shape.size(); - auto norm_shape = x_shape; - norm_shape[axis] = symbol::DimExpr(1); + // Directly modify x_shape at the specified axis. infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(norm_shape)}); + symbol::TensorShapeOrDataDimExprs(x_shape).SetDim( + axis, symbol::DimExpr(1))}); } return true; @@ -1221,59 +1221,64 @@ bool NumelOpInferSymbolicShape(pir::Operation *op, bool PNormOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - auto x_shape_or_data = + const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const auto &x_shape = x_shape_or_data.shape(); - auto x_rank = x_shape.size(); + int x_rank = x_shape.size(); int axis = op->attribute<pir::Int32Attribute>("axis").data(); bool keepdim = op->attribute<pir::BoolAttribute>("keepdim").data(); bool asvector = op->attribute<pir::BoolAttribute>("asvector").data(); - PADDLE_ENFORCE_GE(axis, - -x_rank, - common::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], R is " - "the rank of Input(X). But received axis: %d, R: %d. " - "Current Input(X)'s shape is=[%s].", - axis, - x_rank, - x_shape)); - PADDLE_ENFORCE_LT(axis, - x_rank, - common::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], R is " - "the rank of Input(X). But received axis: %d, R: %d. " - "Current Input(X)'s shape is=[%s].", - axis, - x_rank, - x_shape)); + if (axis < 0) { + axis += x_rank; + } + + PADDLE_ENFORCE_GE( + axis, + 0, + common::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is the rank of " + "Input(X). " + "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_shape)); + + PADDLE_ENFORCE_LT( + axis, + x_rank, + common::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is the rank of " + "Input(X). " + "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_shape)); - std::vector<symbol::DimExpr> out_dim_vector; + std::vector<symbol::DimExpr> out_shape; if (asvector) { if (keepdim) { - for (size_t i = 0; i < x_rank; ++i) { - out_dim_vector.emplace_back(symbol::DimExpr(1)); + for (int i = 0; i < x_rank; ++i) { + out_shape.emplace_back(symbol::DimExpr(1)); } } else { - out_dim_vector = {}; + out_shape = {}; } } else { - if (axis < 0) axis += x_rank; - if (keepdim) { - for (size_t i = 0; i < x_rank; ++i) { - if (static_cast<int>(i) == axis) { - out_dim_vector.emplace_back(symbol::DimExpr(1)); + for (int i = 0; i < x_rank; ++i) { + if (i == axis) { + out_shape.emplace_back(symbol::DimExpr(1)); } else { - out_dim_vector.emplace_back(x_shape[i]); + out_shape.emplace_back(x_shape[i]); } } } else { - for (size_t i = 0; i < x_rank; ++i) { - if (static_cast<int>(i) != axis) { - out_dim_vector.emplace_back(x_shape[i]); + for (int i = 0; i < x_rank; ++i) { + if (i != axis) { + out_shape.emplace_back(x_shape[i]); } } } @@ -1282,7 +1287,7 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(out_dim_vector)}); + symbol::TensorShapeOrDataDimExprs(out_shape)}); return true; } From 9919dbdec712fe80245ae94e5614a857fd55745f Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 12:36:59 +0800 Subject: [PATCH 5/8] Removed comments --- .../operator/interface/infer_symbolic_shape/unary_infer_sym.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 7d4b8080c36ac..e876f00cc8531 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1174,7 +1174,6 @@ bool NormOpInferSymbolicShape(pir::Operation *op, if (!is_test) { if (axis < 0) axis += x_shape.size(); - // Directly modify x_shape at the specified axis. infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ From 19af8426ffb1af45dc0a314df39a81c8c7b83e93 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:28:54 +0800 Subject: [PATCH 6/8] Fixed errors returned by CI --- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index e876f00cc8531..13f941b8d5b2d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1174,11 +1174,11 @@ bool NormOpInferSymbolicShape(pir::Operation *op, if (!is_test) { if (axis < 0) axis += x_shape.size(); + x_shape[axis] = symbol::DimExpr(1); infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(x_shape).SetDim( - axis, symbol::DimExpr(1))}); + symbol::TensorShapeOrDataDimExprs(x_shape)}); } return true; From 08663f196a4d14128a894ed62b978b0259b1af4d Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:46:32 +0800 Subject: [PATCH 7/8] Put some old code back --- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 13f941b8d5b2d..b7df8e607d0f8 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1174,11 +1174,12 @@ bool NormOpInferSymbolicShape(pir::Operation *op, if (!is_test) { if (axis < 0) axis += x_shape.size(); - x_shape[axis] = symbol::DimExpr(1); + auto norm_shape = x_shape; + norm_shape[axis] = symbol::DimExpr(1); infer_context->SetShapeOrDataForValue( op->result(1), symbol::ShapeOrDataDimExprs{ - symbol::TensorShapeOrDataDimExprs(x_shape)}); + symbol::TensorShapeOrDataDimExprs(norm_shape)}); } return true; From 595694cde0b1612dcaeffc67d16a2e51afb36206 Mon Sep 17 00:00:00 2001 From: MufanColin <76479709+MufanColin@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:59:00 +0800 Subject: [PATCH 8/8] Use a bool variable to make life easier --- .../infer_symbolic_shape/unary_infer_sym.cc | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index b7df8e607d0f8..1fab7a25b1cab 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1234,20 +1234,11 @@ bool PNormOpInferSymbolicShape(pir::Operation *op, axis += x_rank; } - PADDLE_ENFORCE_GE( - axis, - 0, - common::errors::InvalidArgument( - "Attr(axis) value should be in range [-R, R-1], R is the rank of " - "Input(X). " - "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].", - axis, - x_rank, - x_shape)); + bool axis_valid = (axis >= 0) && (axis < x_rank); - PADDLE_ENFORCE_LT( - axis, - x_rank, + PADDLE_ENFORCE_EQ( + axis_valid, + true, common::errors::InvalidArgument( "Attr(axis) value should be in range [-R, R-1], R is the rank of " "Input(X). "