From 0c9a1934a956e60da323b09cb008b41f51bf619d Mon Sep 17 00:00:00 2001 From: lwkhahaha Date: Thu, 1 Aug 2024 10:39:15 +0800 Subject: [PATCH 1/4] =?UTF-8?q?Infer=20Symbolic=20Shape=20BUAA=20No.35?= =?UTF-8?q?=E3=80=91Add=20equal=5Fall=20op?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../infer_symbolic_shape/binary_infer_sym.cc | 28 +++++++++++++++++++ .../infer_symbolic_shape/binary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 30 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index ab60df5859e9f..4166ec7797e7f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -259,6 +259,34 @@ bool EmbeddingOpInferSymbolicShape( return true; } +bool EqualAllOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + // Obtain the dimensions of x and y + auto x_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + auto y_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); + + // Ensure that the dimensions of x are not smaller than those of y + if (x_dims.size() < y_dims.size()) { + throw phi::errors::InvalidArgument( + "The size of y_dims should not be greater than x_dims."); + } + + // Set the output dimensions + std::vector out_dims = { + symbol::DimExpr()}; // Adjust the dimensions as necessary + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + // Share the LOD (Level of Detail) from x to out + infer_context->SetShapeOrDataForValue( + op->operand_source(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + return true; +} + bool SparseWeightEmbeddingOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index f3ac5ba69c5d9..c038e33bb3336 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -25,6 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 1126966ac60a1..dbd701152d428 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1540,6 +1540,7 @@ func : CompareAllInferMeta kernel : func : equal_all + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : erf args : (Tensor x) From 534f0648708cbfc6cfb588757cae6d59546450eb Mon Sep 17 00:00:00 2001 From: lwkhahaha Date: Thu, 1 Aug 2024 17:49:34 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E3=80=90Infer=20Symbolic=20Shape=20BUAA=20?= =?UTF-8?q?No.35=E3=80=91Add=20equal=5Fall=20op?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../infer_symbolic_shape/binary_infer_sym.cc | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 4166ec7797e7f..388a05267f576 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -261,29 +261,23 @@ bool EmbeddingOpInferSymbolicShape( bool EqualAllOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { - // Obtain the dimensions of x and y - auto x_dims = + const auto &x_dims = infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); - auto y_dims = + const auto &y_dims = infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); - // Ensure that the dimensions of x are not smaller than those of y - if (x_dims.size() < y_dims.size()) { - throw phi::errors::InvalidArgument( - "The size of y_dims should not be greater than x_dims."); - } + PADDLE_ENFORCE_EQ( + x_dims.size(), + y_dime.size(), + common::errors::InvalidArgument( + "The size of dim_y should not be greater than dim_x's.")); - // Set the output dimensions - std::vector out_dims = { - symbol::DimExpr()}; // Adjust the dimensions as necessary + std::vector out_dims = + {}; // Adjust the dimensions as necessary infer_context->SetShapeOrDataForValue( op->result(0), symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); - // Share the LOD (Level of Detail) from x to out - infer_context->SetShapeOrDataForValue( - op->operand_source(0), - symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); return true; } From 2224c5891996d94d0f9214fe041cf8352c0e4b7f Mon Sep 17 00:00:00 2001 From: lwkhahaha <124662571+lwkhahaha@users.noreply.github.com> Date: Thu, 1 Aug 2024 18:34:47 +0800 Subject: [PATCH 3/4] Update binary_infer_sym.cc --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 388a05267f576..4df59ba0c35b4 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -268,7 +268,7 @@ bool EqualAllOpInferSymbolicShape( PADDLE_ENFORCE_EQ( x_dims.size(), - y_dime.size(), + y_dims.size(), common::errors::InvalidArgument( "The size of dim_y should not be greater than dim_x's.")); From 8347290abcdc480354e175df07ee93978c2c316a Mon Sep 17 00:00:00 2001 From: lwkhahaha <124662571+lwkhahaha@users.noreply.github.com> Date: Fri, 2 Aug 2024 13:01:37 +0800 Subject: [PATCH 4/4] Update binary_infer_sym.cc --- .../operator/interface/infer_symbolic_shape/binary_infer_sym.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 4df59ba0c35b4..13708d23399ba 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -266,7 +266,7 @@ bool EqualAllOpInferSymbolicShape( const auto &y_dims = infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); - PADDLE_ENFORCE_EQ( + PADDLE_ENFORCE_GE( x_dims.size(), y_dims.size(), common::errors::InvalidArgument(