diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 79512fb69ea951..d842be1cf2bd17 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -146,6 +146,24 @@ bool AsRealOpInferSymbolicShape(pir::Operation *op, return true; } +bool BipartiteMatchOpInferSymbolicShape( + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { + const auto &dist_mat_shape_or_data = + infer_context->GetShapeOrDataForValue(op->operand_source(0)); + const auto &dims = dist_mat_shape_or_data.shape(); + + PADDLE_ENFORCE_EQ( + dims.size(), + 2, + phi::errors::InvalidArgument("The rank of Input(DistMat) must be 2.")); + + infer_context->SetShapeOrDataForValue(op->result(0), dist_mat_shape_or_data); + + infer_context->SetShapeOrDataForValue(op->result(1), dist_mat_shape_or_data); + + return true; +} + bool CummaxOpInferSymbolicShape(pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { pir::Value operand_source = op->operand_source(0); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h index cf3e4926c964cf..f67b3a3d998f8b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.h @@ -25,6 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsComplex) OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsReal) +OP_DECLARE_INFER_SYMBOLIC_SHAPE(BipartiteMatch) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummax) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummin) OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index 769c5bb3a61dfe..fdae8fe1fff781 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -561,6 +561,7 @@ kernel: func: bipartite_match data_type: dist_mat + interfaces : paddle::dialect::InferSymbolicShapeInterface - op : bitwise_and args : (Tensor x, Tensor y)