Skip to content

Commit

Permalink
【Infer Symbolic Shape BUAA No.35】Add equal_all op (#66888)
Browse files Browse the repository at this point in the history
* Infer Symbolic Shape BUAA No.35】Add equal_all op

* 【Infer Symbolic Shape BUAA No.35】Add equal_all op

* Update binary_infer_sym.cc

* Update binary_infer_sym.cc
  • Loading branch information
lwkhahaha authored Aug 5, 2024
1 parent d9cba9e commit 789a16a
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,28 @@ bool EmbeddingOpInferSymbolicShape(
return true;
}

bool EqualAllOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const auto &x_dims =
infer_context->GetShapeOrDataForValue(op->operand_source(0)).shape();
const auto &y_dims =
infer_context->GetShapeOrDataForValue(op->operand_source(1)).shape();

PADDLE_ENFORCE_GE(
x_dims.size(),
y_dims.size(),
common::errors::InvalidArgument(
"The size of dim_y should not be greater than dim_x's."));

std::vector<symbol::DimExpr> out_dims =
{}; // Adjust the dimensions as necessary
infer_context->SetShapeOrDataForValue(
op->result(0),
symbol::ShapeOrDataDimExprs{symbol::TensorShapeOrDataDimExprs(out_dims)});

return true;
}

bool SparseWeightEmbeddingOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
PADDLE_THROW(common::errors::Unimplemented(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv2d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Conv3d)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cross)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Embedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(EqualAll)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(SparseWeightEmbedding)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ExpandAs)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(FillDiagonalTensor)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/ops/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1551,6 +1551,7 @@
func : CompareAllInferMeta
kernel :
func : equal_all
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : erf
args : (Tensor x)
Expand Down

0 comments on commit 789a16a

Please sign in to comment.