diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 3b1053e14481a8..dc3db6fca672b2 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -146,6 +146,24 @@ bool AsRealOpInferSymbolicShape(pir::Operation *op, return true; } +bool BipartiteMatchOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &dist_mat_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &dims = dist_mat_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + dims.size(), + 2, + phi::errors::InvalidArgument("The rank of Input(DistMat) must be 2.")); + + infer_context->SetShapeOrDataForValue(op->result(0), dist_mat_shape_or_data); + + infer_context->SetShapeOrDataForValue(op->result(1), dist_mat_shape_or_data); + + return true; +} + bool CummaxOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index 48df2c7a8ec049..bfb00bbef9b903 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -25,6 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsComplex) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsReal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BipartiteMatch) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummin) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index ed9c70825c9c35..ea19062d5e6d1f 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -562,6 +562,7 @@ kernel: func: bipartite_match data_type: dist_mat + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bitwise_and args : (Tensor x, Tensor y)