Skip to content

Commit

Permalink
【Infer Symbolic Shape BUAA No.6】Add bipartite_match (#66618)
Browse files Browse the repository at this point in the history
* add BipartiteMatch

* refix
  • Loading branch information
uanu2002 authored Jul 30, 2024
1 parent 3ba5ae6 commit 378fbd9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,24 @@ bool AsRealOpInferSymbolicShape(pir::Operation *op,
return true;
}

bool BipartiteMatchOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &dist_mat_shape_or_data =
infer_context->GetShapeOrDataForValue(op->operand_source(0));
const auto &dims = dist_mat_shape_or_data.shape();

PADDLE_ENFORCE_EQ(
dims.size(),
2,
phi::errors::InvalidArgument("The rank of Input(DistMat) must be 2."));

infer_context->SetShapeOrDataForValue(op->result(0), dist_mat_shape_or_data);

infer_context->SetShapeOrDataForValue(op->result(1), dist_mat_shape_or_data);

return true;
}

bool CummaxOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
pir::Value operand_source = op->operand_source(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Argmin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsComplex)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(AsReal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(BipartiteMatch)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummax)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cummin)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumprod)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,7 @@
kernel:
func: bipartite_match
data_type: dist_mat
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : bitwise_and
args : (Tensor x, Tensor y)
Expand Down

0 comments on commit 378fbd9

Please sign in to comment.