diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index ab60df5859e9f..13708d23399ba 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -259,6 +259,28 @@ bool EmbeddingOpInferSymbolicShape( return true; } +bool EqualAllOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &y_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); + + PADDLE_ENFORCE_GE( + x_dims.size(), + y_dims.size(), + common::errors::InvalidArgument( + "The size of dim_y should not be greater than dim_x's.")); + + std::vector out_dims = + {}; // Adjust the dimensions as necessary + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} + bool SparseWeightEmbeddingOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(phi::errors::Unimplemented( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index f3ac5ba69c5d9..c038e33bb3336 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -25,6 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(BceLoss_) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 1126966ac60a1..dbd701152d428 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1540,6 +1540,7 @@ func : CompareAllInferMeta kernel : func : equal_all + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : erf args : (Tensor x)