From b52b08beecb685c216ef6e3f007ee42a4c72b501 Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 9 Sep 2024 16:40:18 +0800 Subject: [PATCH 1/2] fix --- .../infer_symbolic_shape/unary_infer_sym.cc | 32 +++++++++++++++++++ .../infer_symbolic_shape/unary_infer_sym.h | 1 + .../phi/ops/yaml/inconsistent/static_ops.yaml | 1 + paddle/phi/ops/yaml/legacy/static_ops.yaml | 1 + 4 files changed, 35 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index c43a89f05401d..345e778900203 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1585,6 +1585,38 @@ bool LogsumexpOpInferSymbolicShape( return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all); } +bool LrnInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { + const auto &x_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const std::vector &x_shape = x_shape_or_data.shape(); + int x_size = x_shape.size(); + PADDLE_ENFORCE_EQ( + x_size, + 4, + common::errors::InvalidArgument("Input(input) rank should be 4, " + "but received input rank (%d) != 4", + x_size)); + int n_value = op->attribute("n").data(); + PADDLE_ENFORCE_GT( + n_value, + 0UL, + common::errors::InvalidArgument("Argument(n) should be positive, " + "but received n(%d) not greater than 0", + n_value)); + PADDLE_ENFORCE_EQ( + n_value % 2, + 1UL, + common::errors::InvalidArgument("Argument(n) should be odd value, " + "but received n(%d) is not an odd value", + n_value)); + infer_context->SetShapeOrDataForValue( + op->result(0), symbol::TensorShapeOrDataDimExprs(x_shape)); + infer_context->SetShapeOrDataForValue( + op->result(1), symbol::TensorShapeOrDataDimExprs(x_shape)); + return true; +} + bool LuOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 869d8ca4ab879..53b1011c8140e 100755 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -85,6 +85,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(L1Norm_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(LpPool2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logsumexp) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lrn) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Lu_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Mode) diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index b7736b2e74e29..977a42702a5bb 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -526,6 +526,7 @@ func: lrn data_type: x backward: lrn_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : matmul args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false) diff --git a/paddle/phi/ops/yaml/legacy/static_ops.yaml b/paddle/phi/ops/yaml/legacy/static_ops.yaml index fc0e89d1db7d8..34230c7a4dd21 100755 --- a/paddle/phi/ops/yaml/legacy/static_ops.yaml +++ b/paddle/phi/ops/yaml/legacy/static_ops.yaml @@ -519,6 +519,7 @@ func: lrn data_type: x backward: lrn_grad + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : matmul args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false) From e18ae5c7113b9200d58ee294cb66a384aae0029d Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 11 Sep 2024 08:39:50 +0800 Subject: [PATCH 2/2] fix --- .../interface/infer_symbolic_shape/unary_infer_sym.cc | 4 ++-- paddle/phi/ops/yaml/legacy/static_ops.yaml | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 6507ecb11e55e..3c707eb23dae9 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -1585,8 +1585,8 @@ bool LogsumexpOpInferSymbolicShape( return details::ReduceInferDim(op, infer_context, axis, keepdim, reduce_all); } -bool LrnInferSymbolicShape(pir::Operation *op, - pir::InferSymbolicShapeContext *infer_context) { +bool LrnOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { const auto &x_shape_or_data = infer_context->GetShapeOrDataForValue(op->operand_source(0)); const std::vector &x_shape = x_shape_or_data.shape(); diff --git a/paddle/phi/ops/yaml/legacy/static_ops.yaml b/paddle/phi/ops/yaml/legacy/static_ops.yaml index 34230c7a4dd21..fc0e89d1db7d8 100755 --- a/paddle/phi/ops/yaml/legacy/static_ops.yaml +++ b/paddle/phi/ops/yaml/legacy/static_ops.yaml @@ -519,7 +519,6 @@ func: lrn data_type: x backward: lrn_grad - interfaces : paddle::dialect::InferSymbolicShapeInterface - op : matmul args : (Tensor x, Tensor y, bool transpose_x = false, bool transpose_y = false)