From 789a16ad9c7d37ebbc2f119a4bd75a6f15a819c8 Mon Sep 17 00:00:00 2001 From: lwkhahaha <124662571+lwkhahaha@users.noreply.github.com> Date: Mon, 5 Aug 2024 16:44:20 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Infer=20Symbolic=20Shape=20BUAA=20No.3?= =?UTF-8?q?5=E3=80=91Add=20equal=5Fall=20op=20(#66888)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Infer Symbolic Shape BUAA No.35】Add equal_all op * 【Infer Symbolic Shape BUAA No.35】Add equal_all op * Update binary_infer_sym.cc * Update binary_infer_sym.cc --- .../infer_symbolic_shape/binary_infer_sym.cc | 22 +++++++++++++++++++ .../infer_symbolic_shape/binary_infer_sym.h | 1 + paddle/phi/ops/yaml/ops.yaml | 1 + 3 files changed, 24 insertions(+) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 2824d36a688ec..8d8cedf447a8b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -307,6 +307,28 @@ bool EmbeddingOpInferSymbolicShape( return true; } +bool EqualAllOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &y_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); + + PADDLE_ENFORCE_GE( + x_dims.size(), + y_dims.size(), + common::errors::InvalidArgument( + "The size of dim_y should not be greater than dim_x's.")); + + std::vector out_dims = + {}; // Adjust the dimensions as necessary + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} + bool SparseWeightEmbeddingOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(common::errors::Unimplemented( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index 92e913123ee1d..0da35da694596 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 457e9fea28b30..6cf58a9d08b69 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1551,6 +1551,7 @@ func : CompareAllInferMeta kernel : func : equal_all + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : erf args : (Tensor x)