diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 5f73602b90e9a..1fab7a25b1cab 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1158,12 +1158,32 @@ bool MeanAllOpInferSymbolicShape( // return true; // } -// bool NormOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext *infer_context) -// { -// // pass -// return true; -// } +bool NormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + auto x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(x_shape)}); + + int axis = op->attribute("axis").data(); + bool is_test = op->attribute("is_test").data(); + + if (!is_test) { + if (axis < 0) axis += x_shape.size(); + + auto norm_shape = x_shape; + norm_shape[axis] = symbol::DimExpr(1); + infer_context->SetShapeOrDataForValue( + op->result(1), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(norm_shape)}); + } + + return true; +} bool NonzeroOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { @@ -1198,12 +1218,70 @@ bool NumelOpInferSymbolicShape(pir::Operation *op, return true; } -// bool P_NormOpInferSymbolicShape(pir::Operation *op, -// pir::InferSymbolicShapeContext -// *infer_context) { -// // pass -// return true; -// } + +bool PNormOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &x_shape = x_shape_or_data.shape(); + int x_rank = x_shape.size(); + + int axis = op->attribute("axis").data(); + bool keepdim = op->attribute("keepdim").data(); + bool asvector = op->attribute("asvector").data(); + + if (axis < 0) { + axis += x_rank; + } + + bool axis_valid = (axis >= 0) && (axis < x_rank); + + PADDLE_ENFORCE_EQ( + axis_valid, + true, + common::errors::InvalidArgument( + "Attr(axis) value should be in range [-R, R-1], R is the rank of " + "Input(X). " + "But received axis: %d, R: %d. Current Input(X)'s shape is=[%s].", + axis, + x_rank, + x_shape)); + + std::vector out_shape; + + if (asvector) { + if (keepdim) { + for (int i = 0; i < x_rank; ++i) { + out_shape.emplace_back(symbol::DimExpr(1)); + } + } else { + out_shape = {}; + } + } else { + if (keepdim) { + for (int i = 0; i < x_rank; ++i) { + if (i == axis) { + out_shape.emplace_back(symbol::DimExpr(1)); + } else { + out_shape.emplace_back(x_shape[i]); + } + } + } else { + for (int i = 0; i < x_rank; ++i) { + if (i != axis) { + out_shape.emplace_back(x_shape[i]); + } + } + } + } + + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{ + symbol::TensorShapeOrDataDimExprs(out_shape)}); + + return true; +} // bool PartialSumOpInferSymbolicShape(pir::Operation *op, // pir::InferSymbolicShapeContext diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 9ffec0c4edb10..a34f4768fbe13 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -79,10 +79,10 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(MeanAll) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(MaxPool3DWithIndex) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Multinomial) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nanmedian) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Norm) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Nonzero) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Numel) -// OP_DECLARE_INFER_SYMBOLIC_SHAPE(P_Norm) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(PNorm) // OP_DECLARE_INFER_SYMBOLIC_SHAPE(PartialSum) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Pad3d) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 319d8bea6d9cd..03555972ad86f 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -3367,6 +3367,7 @@ kernel : func : norm backward : norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : npu_identity args : (Tensor x, int format = -1) @@ -3428,6 +3429,7 @@ kernel : func : p_norm backward : p_norm_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : pad args : (Tensor x, int[] paddings, Scalar pad_value)