diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc index 2824d36a688ec..8d8cedf447a8b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.cc @@ -307,6 +307,28 @@ bool EmbeddingOpInferSymbolicShape( return true; } +bool EqualAllOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &x_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape(); + const auto &y_dims = + infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape(); + + PADDLE_ENFORCE_GE( + x_dims.size(), + y_dims.size(), + common::errors::InvalidArgument( + "The size of dim_y should not be greater than dim_x's.")); + + std::vector out_dims = + {}; // Adjust the dimensions as necessary + infer_context->SetShapeOrDataForValue( + op->result(0), + symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)}); + + return true; +} + bool SparseWeightEmbeddingOpInferSymbolicShape( pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(common::errors::Unimplemented( diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h index 92e913123ee1d..0da35da694596 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/binary_infer_sym.h @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll) OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding) OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs) OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 457e9fea28b30..6cf58a9d08b69 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -1551,6 +1551,7 @@ func : CompareAllInferMeta kernel : func : equal_all + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : erf args : (Tensor x)